diff --git a/doc/source/basics/assets.rst b/doc/source/basics/assets.rst index 5c3d22ec7..8e37ed83d 100644 --- a/doc/source/basics/assets.rst +++ b/doc/source/basics/assets.rst @@ -37,7 +37,7 @@ How to Customize Your Assets * Add the ``name``, ``dataset_family``, and ``data`` fields, which allows fairseq2 to find the corresponding dataset loader - * For more detailed information about ``dataset_family``, please refer to :doc:`Dataset Loaders ` + * For more detailed information about ``dataset_family``, please refer to :doc:`Dataset Loaders ` .. code-block:: yaml @@ -105,4 +105,4 @@ In fairseq2, a model card is accessed via :py:class:`fairseq2.assets.AssetCard`. See Also -------- -- :doc:`Dataset Loaders ` +- :doc:`Datasets ` diff --git a/doc/source/reference/api/fairseq2.datasets/index.rst b/doc/source/reference/api/fairseq2.datasets/index.rst index 8a95b4ba0..4cf73b9d5 100644 --- a/doc/source/reference/api/fairseq2.datasets/index.rst +++ b/doc/source/reference/api/fairseq2.datasets/index.rst @@ -4,13 +4,59 @@ fairseq2.datasets .. module:: fairseq2.datasets -This module contains dataset loaders. +=============== +Dataset Loaders +=============== -.. autoclasstree:: fairseq2.datasets - :full: - :zoom: +The dataset loader system in fairseq2 provides a flexible and extensible way to load different types of datasets. +The system uses the concept of dataset families to organize and manage different dataset formats. -.. toctree:: - :maxdepth: 1 +Dataset Family +-------------- - loader +A dataset family represents a specific format or structure of data that requires specialized loading logic. +Each dataset is associated with a family through the ``dataset_family`` field in its asset card. + +Built-in Dataset Families +^^^^^^^^^^^^^^^^^^^^^^^^^ + +fairseq2 includes several built-in dataset families: + +- ``generic_text``: For plain text datasets +- ``generic_parallel_text``: For parallel text/translation datasets +- ``generic_asr``: For automatic speech recognition datasets +- ``generic_speech``: For speech-only datasets +- ``generic_instruction``: For instruction-tuning datasets +- ``generic_preference_optimization``: For preference optimization datasets + +Example Asset Card +^^^^^^^^^^^^^^^^^^ + +.. code-block:: yaml + + name: librispeech_asr + dataset_family: generic_asr + tokenizer: "https://example.com/tokenizer.model" + tokenizer_family: char_tokenizer + +Usage Examples +-------------- + +Loading a Dataset Using Family +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from fairseq2.datasets import load_text_dataset + + # Load using dataset name (will look up asset card) + dataset = load_text_dataset("my_text_dataset") + + # Load using explicit asset card + card = AssetCard(name="custom_dataset", dataset_family="generic_text") + dataset = load_text_dataset(card) + +See Also +-------- + +- :doc:`Text Dataset ` diff --git a/doc/source/reference/api/fairseq2.datasets/loader.rst b/doc/source/reference/api/fairseq2.datasets/loader.rst deleted file mode 100644 index 121e47d1b..000000000 --- a/doc/source/reference/api/fairseq2.datasets/loader.rst +++ /dev/null @@ -1,110 +0,0 @@ -.. module:: fairseq2.datasets.loader - -=============== -Dataset Loaders -=============== - -.. autoclasstree:: fairseq2.datasets.loader - :full: - :zoom: - -The dataset loader system in fairseq2 provides a flexible and extensible way to load different types of datasets. -The system uses the concept of dataset families to organize and manage different dataset formats. 
- -Dataset Family --------------- - -A dataset family represents a specific format or structure of data that requires specialized loading logic. -Each dataset is associated with a family through the ``dataset_family`` field in its asset card. - -Built-in Dataset Families -^^^^^^^^^^^^^^^^^^^^^^^^^ - -fairseq2 includes several built-in dataset families: - -- ``generic_text``: For plain text datasets -- ``generic_parallel_text``: For parallel text/translation datasets -- ``generic_asr``: For automatic speech recognition datasets -- ``generic_speech``: For speech-only datasets -- ``generic_instruction``: For instruction-tuning datasets -- ``generic_preference_optimization``: For preference optimization datasets - -Example Asset Card -^^^^^^^^^^^^^^^^^^ - -.. code-block:: yaml - - name: librispeech_asr - dataset_family: generic_asr - tokenizer: "https://example.com/tokenizer.model" - tokenizer_family: char_tokenizer - -Core Components ---------------- - -DatasetLoader Protocol -^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: fairseq2.datasets.loader.DatasetLoader - :members: - :special-members: __call__ - -AbstractDatasetLoader -^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: fairseq2.datasets.loader.AbstractDatasetLoader - :members: - :show-inheritance: - -DelegatingDatasetLoader -^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: fairseq2.datasets.loader.DelegatingDatasetLoader - :members: - :show-inheritance: - -Utility Functions ------------------ - -.. autofunction:: fairseq2.datasets.loader.is_dataset_card - -.. autofunction:: fairseq2.datasets.loader.get_dataset_family - -Usage Examples --------------- - -1. Loading a Dataset Using Family -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: python - - from fairseq2.datasets import load_text_dataset - - # Load using dataset name (will look up asset card) - dataset = load_text_dataset("my_text_dataset") - - # Load using explicit asset card - card = AssetCard(name="custom_dataset", dataset_family="generic_text") - dataset = load_text_dataset(card) - -2. Registering a Custom Dataset Loader -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. 
code-block:: python - - from fairseq2.datasets import DelegatingDatasetLoader - - # Create your custom dataset loader - class MyCustomDatasetLoader(AbstractDatasetLoader[MyDataset]): - def _load(self, path: Path, card: AssetCard) -> MyDataset: - return MyDataset.from_path(path) - - # Register with a family name - loader = MyCustomDatasetLoader() - load_dataset = DelegatingDatasetLoader() - load_dataset.register("my_custom_family", loader) - -See Also --------- - -- :doc:`Text Dataset ` diff --git a/src/fairseq2/__init__.py b/src/fairseq2/__init__.py index 2dd74f8cd..f39c1953b 100644 --- a/src/fairseq2/__init__.py +++ b/src/fairseq2/__init__.py @@ -12,7 +12,6 @@ # isort: split -import fairseq2.datasets import fairseq2.models # isort: split diff --git a/src/fairseq2/chatbots/llama.py b/src/fairseq2/chatbots/llama.py index 7b8452612..156ebfcd8 100644 --- a/src/fairseq2/chatbots/llama.py +++ b/src/fairseq2/chatbots/llama.py @@ -14,7 +14,8 @@ from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog, ChatMessage from fairseq2.chatbots.handler import ChatbotHandler -from fairseq2.data.text import LLaMA3Tokenizer, TextTokenEncoder, TextTokenizer +from fairseq2.data.text import TextTokenEncoder, TextTokenizer +from fairseq2.data.text.tokenizers.llama import LLaMA3Tokenizer from fairseq2.generation import SequenceGenerator from fairseq2.nn.utils.module import infer_device diff --git a/src/fairseq2/data/text/__init__.py b/src/fairseq2/data/text/__init__.py index 815d24a16..02ea1b5ac 100644 --- a/src/fairseq2/data/text/__init__.py +++ b/src/fairseq2/data/text/__init__.py @@ -11,82 +11,8 @@ from fairseq2.data.text.converters import StrToTensorConverter as StrToTensorConverter from fairseq2.data.text.text_reader import LineEnding as LineEnding from fairseq2.data.text.text_reader import read_text as read_text -from fairseq2.data.text.tokenizers.char_tokenizer import ( - CHAR_TOKENIZER_FAMILY as CHAR_TOKENIZER_FAMILY, -) -from fairseq2.data.text.tokenizers.handler import ( - StandardTextTokenizerHandler as StandardTextTokenizerHandler, -) -from fairseq2.data.text.tokenizers.handler import ( - TextTokenizerHandler as TextTokenizerHandler, -) -from fairseq2.data.text.tokenizers.handler import ( - TextTokenizerLoader as TextTokenizerLoader, -) -from fairseq2.data.text.tokenizers.handler import ( - TextTokenizerNotFoundError as TextTokenizerNotFoundError, -) -from fairseq2.data.text.tokenizers.handler import ( - get_text_tokenizer_family as get_text_tokenizer_family, -) -from fairseq2.data.text.tokenizers.llama import ( - LLAMA_TOKENIZER_FAMILY as LLAMA_TOKENIZER_FAMILY, -) -from fairseq2.data.text.tokenizers.llama import LLaMA3Tokenizer as LLaMA3Tokenizer -from fairseq2.data.text.tokenizers.mistral import ( - MISTRAL_TOKENIZER_FAMILY as MISTRAL_TOKENIZER_FAMILY, -) -from fairseq2.data.text.tokenizers.nllb import ( - NLLB_TOKENIZER_FAMILY as NLLB_TOKENIZER_FAMILY, -) -from fairseq2.data.text.tokenizers.nllb import NllbTokenizer as NllbTokenizer -from fairseq2.data.text.tokenizers.ref import ( - resolve_text_tokenizer_reference as resolve_text_tokenizer_reference, -) -from fairseq2.data.text.tokenizers.s2t_transformer import ( - S2T_TRANSFORMER_TOKENIZER_FAMILY as S2T_TRANSFORMER_TOKENIZER_FAMILY, -) -from fairseq2.data.text.tokenizers.s2t_transformer import ( - S2TTransformerTokenizer as S2TTransformerTokenizer, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - BasicSentencePieceTokenizer as BasicSentencePieceTokenizer, -) -from fairseq2.data.text.tokenizers.sentencepiece 
import ( - RawSentencePieceTokenizer as RawSentencePieceTokenizer, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - SentencePieceDecoder as SentencePieceDecoder, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - SentencePieceEncoder as SentencePieceEncoder, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - SentencePieceModel as SentencePieceModel, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - SentencePieceTokenizer as SentencePieceTokenizer, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - load_basic_sentencepiece as load_basic_sentencepiece, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - load_raw_sentencepiece as load_raw_sentencepiece, -) -from fairseq2.data.text.tokenizers.sentencepiece import ( - vocab_info_from_sentencepiece as vocab_info_from_sentencepiece, -) -from fairseq2.data.text.tokenizers.static import ( - load_text_tokenizer as load_text_tokenizer, -) -from fairseq2.data.text.tokenizers.tiktoken import TiktokenDecoder as TiktokenDecoder -from fairseq2.data.text.tokenizers.tiktoken import TiktokenEncoder as TiktokenEncoder -from fairseq2.data.text.tokenizers.tiktoken import ( - TiktokenTokenizer as TiktokenTokenizer, -) -from fairseq2.data.text.tokenizers.tokenizer import ( - AbstractTextTokenizer as AbstractTextTokenizer, -) -from fairseq2.data.text.tokenizers.tokenizer import TextTokenDecoder as TextTokenDecoder -from fairseq2.data.text.tokenizers.tokenizer import TextTokenEncoder as TextTokenEncoder -from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer as TextTokenizer +from fairseq2.data.text.tokenizers import AbstractTextTokenizer as AbstractTextTokenizer +from fairseq2.data.text.tokenizers import TextTokenDecoder as TextTokenDecoder +from fairseq2.data.text.tokenizers import TextTokenEncoder as TextTokenEncoder +from fairseq2.data.text.tokenizers import TextTokenizer as TextTokenizer +from fairseq2.data.text.tokenizers import load_text_tokenizer as load_text_tokenizer diff --git a/src/fairseq2/data/text/tokenizers/__init__.py b/src/fairseq2/data/text/tokenizers/__init__.py index e69de29bb..eb2f218a3 100644 --- a/src/fairseq2/data/text/tokenizers/__init__.py +++ b/src/fairseq2/data/text/tokenizers/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from fairseq2.data.text.tokenizers.handler import ( + StandardTextTokenizerHandler as StandardTextTokenizerHandler, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerHandler as TextTokenizerHandler, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerLoader as TextTokenizerLoader, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerNotFoundError as TextTokenizerNotFoundError, +) +from fairseq2.data.text.tokenizers.handler import ( + get_text_tokenizer_family as get_text_tokenizer_family, +) +from fairseq2.data.text.tokenizers.ref import ( + resolve_text_tokenizer_reference as resolve_text_tokenizer_reference, +) +from fairseq2.data.text.tokenizers.static import ( + load_text_tokenizer as load_text_tokenizer, +) +from fairseq2.data.text.tokenizers.tokenizer import ( + AbstractTextTokenizer as AbstractTextTokenizer, +) +from fairseq2.data.text.tokenizers.tokenizer import TextTokenDecoder as TextTokenDecoder +from fairseq2.data.text.tokenizers.tokenizer import TextTokenEncoder as TextTokenEncoder +from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer as TextTokenizer diff --git a/src/fairseq2/data/text/tokenizers/llama.py b/src/fairseq2/data/text/tokenizers/llama.py index cb5330c99..9e95aca35 100644 --- a/src/fairseq2/data/text/tokenizers/llama.py +++ b/src/fairseq2/data/text/tokenizers/llama.py @@ -91,9 +91,12 @@ def create_encoder( case "prompt_response": prefix_tokens = [] suffix_tokens = [self._eos_token] + case "as_is": + prefix_tokens = [] + suffix_tokens = [] case _: raise ValueError( - f"`mode` must be 'default' or 'prompt', but is '{mode}' instead." + f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response, as_is" ) return TiktokenEncoder( diff --git a/src/fairseq2/data/text/tokenizers/sentencepiece.py b/src/fairseq2/data/text/tokenizers/sentencepiece.py index 6b63e1d23..6857dbd02 100644 --- a/src/fairseq2/data/text/tokenizers/sentencepiece.py +++ b/src/fairseq2/data/text/tokenizers/sentencepiece.py @@ -218,7 +218,7 @@ def create_encoder( suffix_tokens = [""] case _: raise ValueError( - f"`mode` must be 'default' or 'prompt', but is '{mode}' instead." 
+ f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response" ) return SentencePieceEncoder( diff --git a/src/fairseq2/datasets/__init__.py b/src/fairseq2/datasets/__init__.py index a72b31049..bedc5bd3c 100644 --- a/src/fairseq2/datasets/__init__.py +++ b/src/fairseq2/datasets/__init__.py @@ -11,17 +11,12 @@ from fairseq2.datasets.batching import StaticBatching as StaticBatching from fairseq2.datasets.data_reader import DataPipelineReader as DataPipelineReader from fairseq2.datasets.data_reader import DataReader as DataReader +from fairseq2.datasets.data_reader import SyncMode as SyncMode +from fairseq2.datasets.error import DataReadError as DataReadError from fairseq2.datasets.error import DatasetError as DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader as AbstractDatasetLoader -from fairseq2.datasets.loader import DatasetLoader as DatasetLoader -from fairseq2.datasets.loader import DelegatingDatasetLoader as DelegatingDatasetLoader -from fairseq2.datasets.loader import get_dataset_family as get_dataset_family -from fairseq2.datasets.loader import is_dataset_card as is_dataset_card - -# isort: split - -import fairseq2.datasets.asr -import fairseq2.datasets.instruction -import fairseq2.datasets.parallel_text -import fairseq2.datasets.speech -import fairseq2.datasets.text +from fairseq2.datasets.handler import DatasetHandler as DatasetHandler +from fairseq2.datasets.handler import DatasetLoader as DatasetLoader +from fairseq2.datasets.handler import DatasetNotFoundError as DatasetNotFoundError +from fairseq2.datasets.handler import StandardDatasetHandler as StandardDatasetHandler +from fairseq2.datasets.handler import get_dataset_family as get_dataset_family +from fairseq2.datasets.static import load_dataset as load_dataset diff --git a/src/fairseq2/datasets/asr.py b/src/fairseq2/datasets/asr.py index 0b8f019af..be62244bf 100644 --- a/src/fairseq2/datasets/asr.py +++ b/src/fairseq2/datasets/asr.py @@ -8,14 +8,14 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Literal, cast, final +from typing import Any, Final, cast, final import torch from torch import Tensor from torch.nn.functional import layer_norm from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -29,9 +29,10 @@ from fairseq2.data.audio import AudioDecoder from fairseq2.data.text import StrSplitter, TextTokenizer, read_text from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.error import DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.error import DatasetError, SplitNotFoundError +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.nn.padding import get_seqs_and_padding_mask @@ -57,12 +58,11 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, 
num_prefetch: int = 1, seed: int = 2, - **extras: Any, ) -> DataReader[Seq2SeqBatch]: """Create a dataset reader. @@ -115,8 +115,6 @@ def create_reader( The number of batches to prefetch in background. :param seed: The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. """ @abstractmethod @@ -124,47 +122,51 @@ def splits(self) -> set[str]: """Return the set of splits.""" -load_asr_dataset = DelegatingDatasetLoader[AsrDataset]() - - # TODO: FIX, INFER npc = 10 +GENERIC_ASR_DATASET_FAMILY: Final = "generic_asr" + + # TODO: Work in progress! @final class GenericAsrDataset(AsrDataset): """Represents a generic manifest-based ASR dataset.""" + _name: str _manifest_dir: Path _splits: set[str] - def __init__(self, manifest_dir: Path, splits: set[str]) -> None: + def __init__(self, name: str, manifest_dir: Path, splits: set[str]) -> None: """ :param manifest_dir: The directory under which the manifest files resides. :param splits: The available splits. """ + self._name = name self._manifest_dir = manifest_dir self._splits = splits - @classmethod - def from_path(cls, path: Path) -> GenericAsrDataset: - """Load a :class:`GenericAsrDataset` from ``path``.""" + @staticmethod + def from_path(path: Path, name: str | None = None) -> GenericAsrDataset: + if name is None: + name = f"path:{path.name}" + path = path.expanduser().resolve() if not path.is_dir(): - return GenericAsrDataset(manifest_dir=path.parent, splits={path.stem}) + return GenericAsrDataset(name, manifest_dir=path.parent, splits={path.stem}) try: splits = {f.stem for f in path.glob("*.tsv")} except OSError as ex: - raise RuntimeError( - "The splits cannot be determined. See nested exception for details." + raise DatasetError( + f"The splits under the '{path}' directory cannot be determined. See the nested exception for details." ) from ex - return GenericAsrDataset(path, splits) + return GenericAsrDataset(name, path, splits) @override def create_reader( @@ -182,13 +184,12 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, seed: int = 2, cached_fd_count: int = 1000, - **extras: Any, ) -> DataPipelineReader[Seq2SeqBatch]: """ :param cached_fd_count: @@ -196,9 +197,7 @@ def create_reader( audio files. """ if split not in self._splits: - raise ValueError( - f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits))}" - ) + raise SplitNotFoundError(self._name, split, self._splits) audio_dir = self._retrieve_data_directory(split) @@ -234,7 +233,7 @@ def create_reader( ) elif isinstance(batching, StaticBatching): # Filter out out-of-range audios. - def skip(example: dict[str, Any]) -> bool: + def skip(example: dict[str, object]) -> bool: audio_len = cast(int, example["audio_size"]) return audio_len >= min_audio_len and audio_len <= max_audio_len @@ -244,7 +243,7 @@ def skip(example: dict[str, Any]) -> bool: # Bucket `batch_size` examples. builder.bucket(batching.batch_size, drop_remainder=drop_remainder) else: - raise RuntimeError(f"`{batching}` is not supported.") + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. 
if batch_shuffle_window != 1: @@ -318,6 +317,7 @@ def to_batch(example: dict[str, Any]) -> Seq2SeqBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[Seq2SeqBatch]( + self._name, pipeline, gang, num_accumulate=num_accumulate, @@ -334,14 +334,15 @@ def _retrieve_data_directory(self, split: str) -> Path: line = fp.readline().rstrip() except OSError as ex: raise DatasetError( - f"{manifest_file} cannot be read. See nested exception for details." + self._name, + f"The {manifest_file} manifest file cannot be read. See the nested exception for details.", ) from ex try: return Path(line) except ValueError: raise DatasetError( - f"The first line of {manifest_file} must point to a data directory." + f"The first line of the '{manifest_file}' manifest file must point to a data directory." ) from None def _read_manifest(self, split: str) -> DataPipelineBuilder: @@ -381,18 +382,7 @@ def splits(self) -> set[str]: return self._splits -@final -class GenericAsrDatasetLoader(AbstractDatasetLoader[GenericAsrDataset]): - @override - def _load(self, path: Path, card: AssetCard) -> GenericAsrDataset: - try: - return GenericAsrDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_asr_dataset = GenericAsrDatasetLoader() - -load_asr_dataset.register("generic_asr", load_generic_asr_dataset) +def load_asr_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> AsrDataset: + return load_dataset(name_or_card, AsrDataset, force=force) diff --git a/src/fairseq2/datasets/data_reader.py b/src/fairseq2/datasets/data_reader.py index dfac06042..478cb35ea 100644 --- a/src/fairseq2/datasets/data_reader.py +++ b/src/fairseq2/datasets/data_reader.py @@ -8,19 +8,14 @@ from abc import ABC, abstractmethod from collections.abc import Iterator, Mapping -from typing import Any, Literal, TypeVar, final +from typing import Literal, TypeAlias, TypeVar, final from typing_extensions import Self, override -from fairseq2.data import DataPipeline +from fairseq2.data import DataPipeline, DataPipelineError +from fairseq2.datasets.error import DataReadError from fairseq2.datasets.utils import _min_num_batches, _sum_num_batches -from fairseq2.gang import Gang -from fairseq2.logging import get_log_writer - -log = get_log_writer(__name__) - - -BatchT = TypeVar("BatchT") +from fairseq2.gang import Gang, GangError BatchT_co = TypeVar("BatchT_co", covariant=True) @@ -41,11 +36,11 @@ def reset(self) -> None: """Reset state and move back to the first batch.""" @abstractmethod - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> dict[str, object]: ... @abstractmethod - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: ... 
@property @@ -54,10 +49,17 @@ def num_accumulate(self) -> int: """The number of batches accumulated in each iteration.""" +SyncMode: TypeAlias = Literal["until_first", "until_last"] + + +BatchT = TypeVar("BatchT") + + @final class DataPipelineReader(DataReader[BatchT]): """Reads batches of examples from a dataset using a :class:`DataPipeline`.""" + _name: str _pipeline: DataPipeline _pipeline_iter: Iterator[BatchT] _gang: Gang @@ -67,15 +69,17 @@ class DataPipelineReader(DataReader[BatchT]): def __init__( self, + name: str, pipeline: DataPipeline, gang: Gang, *, num_accumulate: int = 1, drop_remainder: bool = True, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", ) -> None: """ + :param name: The name of the dataset. :param pipeline: The data pipeline to iterate over. :param gang: @@ -97,6 +101,7 @@ def __init__( reaches end of data; ranks that have already reached their end of data will return an empty list of batches. """ + self._name = name self._pipeline = pipeline self._pipeline_iter = iter(pipeline) self._gang = gang @@ -122,6 +127,10 @@ def __next__(self) -> list[BatchT]: batch = next(self._pipeline_iter) except StopIteration: break + except DataPipelineError as ex: + raise DataReadError( + self._name, "The data pipeline has failed to read the next batch. See the nested exception for details." # fmt: skip + ) from ex batches.append(batch) @@ -133,13 +142,18 @@ def __next__(self) -> list[BatchT]: local_num_batches = len(batches) if self._sync_batches and self._gang.size > 1: - if self._sync_until_last: - num_batches = _sum_num_batches(local_num_batches, self._gang) - else: - num_batches = _min_num_batches(local_num_batches, self._gang, log) - - if num_batches != local_num_batches: - batches = batches[:num_batches] + try: + if self._sync_until_last: + num_batches = _sum_num_batches(local_num_batches, self._gang) + else: + num_batches = _min_num_batches(local_num_batches, self._gang) + + if num_batches != local_num_batches: + batches = batches[:num_batches] + except GangError as ex: + raise DataReadError( + self._name, "The batch synchronization of the gang processes has failed. See the nested exception for details." 
# fmt: skip + ) from ex else: num_batches = local_num_batches @@ -157,11 +171,11 @@ def reset(self) -> None: self._pipeline.reset() @override - def state_dict(self) -> dict[str, Any]: + def state_dict(self) -> dict[str, object]: return self._pipeline.state_dict() @override - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: self._eod = False self._pipeline.load_state_dict(state_dict) diff --git a/src/fairseq2/datasets/error.py b/src/fairseq2/datasets/error.py index 5a0b94f1d..71e312ec3 100644 --- a/src/fairseq2/datasets/error.py +++ b/src/fairseq2/datasets/error.py @@ -6,6 +6,29 @@ from __future__ import annotations +from collections.abc import Set -class DatasetError(RuntimeError): - """Raised when a dataset can't be read.""" + +class DatasetError(Exception): + pass + + +class DataReadError(Exception): + pass + + +class SplitNotFoundError(LookupError): + name: str + split: str + available_splits: Set[str] + + def __init__(self, name: str, split: str, available_splits: Set[str]) -> None: + s = ", ".join(sorted(available_splits)) + + super().__init__( + f"`split` must be one of the following splits, but is '{split}' instead: {s}" + ) + + self.name = name + self.split = split + self.available_splits = available_splits diff --git a/src/fairseq2/datasets/handler.py b/src/fairseq2/datasets/handler.py new file mode 100644 index 000000000..3ec2f7334 --- /dev/null +++ b/src/fairseq2/datasets/handler.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Protocol, final + +from typing_extensions import override + +from fairseq2.assets import AssetCard, AssetDownloadManager, AssetError +from fairseq2.datasets.error import DatasetError + + +class DatasetHandler(ABC): + @abstractmethod + def load(self, card: AssetCard, *, force: bool) -> object: + ... + + @abstractmethod + def load_from_path(self, path: Path) -> object: + ... + + @property + @abstractmethod + def kls(self) -> type: + ... + + +class DatasetNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known dataset.") + + self.name = name + + +class DatasetLoader(Protocol): + def __call__(self, path: Path, name: str | None) -> object: + ... + + +@final +class StandardDatasetHandler(DatasetHandler): + _kls: type + _loader: DatasetLoader + _asset_download_manager: AssetDownloadManager + + def __init__( + self, + kls: type, + loader: DatasetLoader, + asset_download_manager: AssetDownloadManager, + ) -> None: + self._kls = kls + self._loader = loader + self._asset_download_manager = asset_download_manager + + @override + def load(self, card: AssetCard, *, force: bool) -> object: + dataset_uri = card.field("data").as_uri() + + path = self._asset_download_manager.download_dataset( + dataset_uri, card.name, force=force + ) + + try: + return self._loader(path, card.name) + except DatasetError as ex: + raise AssetError( + f"The constructor of the '{card.name}' dataset has raised an error. See the nested exception for details." 
+ ) from ex + + @override + def load_from_path(self, path: Path) -> object: + return self._loader(path, name=None) + + @override + @property + def kls(self) -> type: + return self._kls + + +def get_dataset_family(card: AssetCard) -> str: + return card.field("dataset_family").as_(str) # type: ignore[no-any-return] diff --git a/src/fairseq2/datasets/instruction.py b/src/fairseq2/datasets/instruction.py index aaf204c05..2822dfb05 100644 --- a/src/fairseq2/datasets/instruction.py +++ b/src/fairseq2/datasets/instruction.py @@ -10,12 +10,12 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal, NoReturn, cast, final +from typing import Any, Final, cast, final import torch from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -27,9 +27,11 @@ ) from fairseq2.data.text import TextTokenizer from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.error import DatasetError, SplitNotFoundError +from fairseq2.datasets.static import load_dataset from fairseq2.datasets.utils import _load_files_and_weights +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.sequence import SequenceBatch from fairseq2.nn.padding import get_seqs_and_padding_mask @@ -52,14 +54,13 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, - src_encode_mode: str = "prompt", - tgt_encode_mode: str = "prompt_response", + source_encode_mode: str = "prompt", + target_encode_mode: str = "prompt_response", seed: int = 2, - **extras: Any, ) -> DataReader[SequenceBatch]: """Create a dataset reader. @@ -106,14 +107,12 @@ def create_reader( used with gradient accumulation during training. :param num_prefetch: The number of batches to prefetch in background. + :param source_encode_mode: + The mode to encode the source text. + :param target_encode_mode: + The mode to encode the target text. :param seed: The seed to initialize the random number generators used internally. - :param src_encode_mode: - The mode to encode the prompt - :param tgt_encode_mode: - The mode to encode the target - :param extras: - The extra parameters specific to the dataset implementation. """ @abstractmethod @@ -127,10 +126,9 @@ def create_prompt_reader( *, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", num_prefetch: int = 1, - src_encode_mode: str = "prompt", - **extras: Any, + source_encode_mode: str = "prompt", ) -> DataPipelineReader[SequenceBatch]: """Create a dataset reader for evaluation. @@ -156,10 +154,8 @@ def create_prompt_reader( same number of batches (e.g. during training). :param num_prefetch: The number of batches to prefetch in background. 
- :param src_encode_mode: - The mode to encode the prompt - :param extras: - The extra parameters specific to the dataset implementation. + :param source_encode_mode: + The mode to encode the source text. """ @abstractmethod @@ -167,22 +163,23 @@ def splits(self) -> set[str]: """Return the set of splits.""" -load_instruction_dataset = DelegatingDatasetLoader[InstructionDataset]() - - # TODO: FIX, INFER npc = 10 +GENERIC_INSTRUCTION_DATASET_FAMILY: Final = "generic_instruction" + + # TODO: Work in progress! @final class GenericInstructionDataset(InstructionDataset): """Represents a generic JSONL instruction dataset.""" + _name: str _splits: dict[str, tuple[Sequence[Path], Sequence[float]]] def __init__( - self, splits: dict[str, tuple[Sequence[Path], Sequence[float]]] + self, name: str, splits: dict[str, tuple[Sequence[Path], Sequence[float]]] ) -> None: """ :param files: @@ -190,6 +187,8 @@ def __init__( :param weights: The weight of each file in ``files``. """ + self._name = name + for split, (files, weights) in splits.items(): if len(files) != len(weights): raise ValueError( @@ -198,23 +197,32 @@ def __init__( self._splits = splits - @classmethod - def from_path(cls, path: Path) -> GenericInstructionDataset: - """Load a :class:`InstructionDataset` from ``path``.""" + @staticmethod + def from_path(path: Path, name: str | None = None) -> GenericInstructionDataset: + if name is None: + name = f"path:{path.name}" + splits: dict[str, tuple[Sequence[Path], Sequence[float]]] = {} - for child_path in path.iterdir(): - if child_path.is_dir(): - files, weights = _load_files_and_weights(child_path) + if path.is_dir(): + try: + child_dirs = [p for p in path.iterdir() if p.is_dir()] + except OSError as ex: + raise DatasetError( + name, f"The files under the '{path}' directory cannot be retrieved. See the nested exception for details." # fmt: skip + ) from ex + + for child_dir in child_dirs: + files, weights = _load_files_and_weights(name, child_dir) - splits[child_path.name] = (files, weights) + splits[child_dir.name] = (files, weights) if not splits: - files, weights = _load_files_and_weights(path) + files, weights = _load_files_and_weights(name, path) splits["default"] = (files, weights) - return GenericInstructionDataset(splits) + return GenericInstructionDataset(name, splits) @override def create_reader( @@ -230,19 +238,19 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, + source_encode_mode: str = "prompt", + target_encode_mode: str = "prompt_response", seed: int = 2, - src_encode_mode: str = "prompt", - tgt_encode_mode: str = "prompt_response", - **extras: Any, ) -> DataPipelineReader[SequenceBatch]: - try: - files, weights = self._splits[split] - except KeyError: - self._raise_split_error(split) + files_weights = self._splits.get(split) + if files_weights is None: + raise SplitNotFoundError(self._name, split, self._splits.keys()) + + files, weights = files_weights if len(files) == 1: builder = self._read_jsonl(files[0], tokenizer) @@ -272,22 +280,22 @@ def create_reader( seed += gang.rank - # Encode prompt and target texts. - prompt_encoder = tokenizer.create_encoder(mode=src_encode_mode) - target_encoder = tokenizer.create_encoder(mode=tgt_encode_mode) + # Encode source and target texts. 
+ source_encoder = tokenizer.create_encoder(mode=source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=target_encode_mode) - builder.map(prompt_encoder, selector="src", num_parallel_calls=npc) + builder.map(source_encoder, selector="src", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt", num_parallel_calls=npc) def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: id_ = example.get("id") - prompt_indices = example["src"] + source_indices = example["src"] target_indices = example["tgt"] - indices = torch.cat([prompt_indices, target_indices]) + indices = torch.cat([source_indices, target_indices]) - target_mask = torch.arange(len(indices)) >= len(prompt_indices) + target_mask = torch.arange(len(indices)) >= len(source_indices) return {"id": id_, "indices": indices, "target_mask": target_mask} @@ -315,7 +323,7 @@ def skip(example: dict[str, Any]) -> bool: # Bucket `batch_size` examples. builder.bucket(batching.batch_size, drop_remainder=drop_remainder) else: - raise RuntimeError(f"`{batching}` is not supported.") + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. if batch_shuffle_window != 1: @@ -352,6 +360,7 @@ def to_batch(example: dict[str, Any]) -> SequenceBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[SequenceBatch]( + self._name, pipeline, gang, num_accumulate=num_accumulate, @@ -371,15 +380,14 @@ def create_prompt_reader( *, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", num_prefetch: int = 1, - src_encode_mode: str = "prompt", - **extras: Any, + source_encode_mode: str = "prompt", ) -> DataPipelineReader[SequenceBatch]: try: files, weights = self._splits[split] except KeyError: - self._raise_split_error(split) + raise SplitNotFoundError(self._name, split, self._splits.keys()) from None if len(files) == 1: builder = self._read_jsonl(files[0], tokenizer) @@ -396,17 +404,17 @@ def create_prompt_reader( # Shard builder.shard(gang.rank, gang.size, allow_uneven=True) - # Encode prompt texts. - text_encoder = tokenizer.create_encoder(mode=src_encode_mode) + # Encode source texts. 
+ text_encoder = tokenizer.create_encoder(mode=source_encode_mode) def encode(example: dict[str, Any]) -> dict[str, Any]: id_ = example.get("id") - prompt = example["src"] + source = example["src"] - indices = text_encoder(prompt) + indices = text_encoder(source) - return {"id": id_, "prompt": prompt, "indices": indices} + return {"id": id_, "prompt": source, "indices": indices} builder.map(encode, num_parallel_calls=npc) @@ -438,6 +446,7 @@ def to_batch(example: dict[str, Any]) -> SequenceBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[SequenceBatch]( + self._name, pipeline, gang, drop_remainder=drop_remainder, @@ -455,30 +464,12 @@ def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuild return read_sequence(lines).map(json.loads, num_parallel_calls=npc) - def _raise_split_error(self, split: str) -> NoReturn: - raise ValueError( - f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits.keys()))}" - ) from None - @override def splits(self) -> set[str]: return set(self._splits.keys()) -@final -class GenericInstructionDatasetLoader(AbstractDatasetLoader[GenericInstructionDataset]): - @override - def _load(self, path: Path, card: AssetCard) -> GenericInstructionDataset: - try: - return GenericInstructionDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_instruction_dataset = GenericInstructionDatasetLoader() - -load_instruction_dataset.register( - "generic_instruction", load_generic_instruction_dataset -) +def load_instruction_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> InstructionDataset: + return load_dataset(name_or_card, InstructionDataset, force=force) diff --git a/src/fairseq2/datasets/loader.py b/src/fairseq2/datasets/loader.py deleted file mode 100644 index 378389314..000000000 --- a/src/fairseq2/datasets/loader.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Protocol, TypeVar, cast, final - -from fairseq2.assets import ( - AssetCard, - AssetCardError, - AssetDownloadManager, - AssetError, - AssetStore, - InProcAssetDownloadManager, - default_asset_store, -) - -DatasetT = TypeVar("DatasetT") - -DatasetT_co = TypeVar("DatasetT_co", covariant=True) - - -class DatasetLoader(Protocol[DatasetT_co]): - """Loads datasets of type ``DatasetT```.""" - - def __call__( - self, - dataset_name_or_card: str | AssetCard, - *, - force: bool = False, - progress: bool = True, - ) -> DatasetT_co: - """ - :param dataset_name_or_card: - The name or the asset card of the dataset to load. - :param force: - If ``True``, downloads the dataset even if it is already in cache. - :param progress: - If ``True``, displays a progress bar to stderr. 
- """ - - -class AbstractDatasetLoader(ABC, DatasetLoader[DatasetT]): - """Provides a skeletal implementation of :class:`DatasetLoader`.""" - - _asset_store: AssetStore - _download_manager: AssetDownloadManager - - def __init__( - self, - *, - asset_store: AssetStore | None = None, - download_manager: AssetDownloadManager | None = None, - ) -> None: - """ - :param asset_store: - The asset store where to check for available datasets. If ``None``, - the default asset store will be used. - :param download_manager: - The download manager. If ``None``, the default download manager will - be used. - """ - self._asset_store = asset_store or default_asset_store - self._download_manager = download_manager or InProcAssetDownloadManager() - - @final - def __call__( - self, - dataset_name_or_card: str | AssetCard, - *, - force: bool = False, - progress: bool = True, - ) -> DatasetT: - if isinstance(dataset_name_or_card, AssetCard): - card = dataset_name_or_card - else: - card = self._asset_store.retrieve_card(dataset_name_or_card) - - dataset_uri = card.field("data").as_uri() - - try: - path = self._download_manager.download_dataset( - dataset_uri, card.name, force=force, progress=progress - ) - except ValueError as ex: - raise AssetCardError( - card.name, f"The value of the field 'data' of the asset card '{card.name}' must be a URI. See nested exception for details." # fmt: skip - ) from ex - - try: - return self._load(path, card) - except ValueError as ex: - raise AssetError( - f"The {card.name} dataset cannot be loaded. See nested exception for details." - ) from ex - - @abstractmethod - def _load(self, path: Path, card: AssetCard) -> DatasetT: - """ - :param path: - The path to the dataset. - :param card: - The asset card of the dataset. - """ - - -@final -class DelegatingDatasetLoader(DatasetLoader[DatasetT]): - """Loads datasets of type ``DatasetT`` using registered loaders.""" - - _asset_store: AssetStore - _loaders: dict[str, DatasetLoader[DatasetT]] - - def __init__(self, *, asset_store: AssetStore | None = None) -> None: - """ - :param asset_store: - The asset store where to check for available datasets. If ``None``, - the default asset store will be used. - """ - self._asset_store = asset_store or default_asset_store - - self._loaders = {} - - def __call__( - self, - dataset_name_or_card: str | AssetCard, - *, - force: bool = False, - progress: bool = True, - ) -> DatasetT: - if isinstance(dataset_name_or_card, AssetCard): - card = dataset_name_or_card - else: - card = self._asset_store.retrieve_card(dataset_name_or_card) - - family = card.field("dataset_family").as_(str) - - try: - loader = self._loaders[family] - except KeyError: - raise AssetError( - f"The value of the field 'dataset_family' of the asset card '{card.name}' must be a supported dataset family, but is '{family}' instead." - ) from None - - return loader(card, force=force, progress=progress) - - def register(self, family: str, loader: DatasetLoader[DatasetT]) -> None: - """Register a dataset loader to use with this loader. - - :param family: - The dataset type. If the 'dataset_family' field of an asset card - matches this value, the specified ``loader`` will be used. - :param loader: - The dataset loader. - """ - if family in self._loaders: - raise ValueError( - f"`family` must be a unique dataset family name, but '{family}' is already registered." 
- ) - - self._loaders[family] = loader - - def supports(self, dataset_name_or_card: str | AssetCard) -> bool: - """Return ``True`` if the specified dataset has a registered loader.""" - if isinstance(dataset_name_or_card, AssetCard): - card = dataset_name_or_card - else: - card = self._asset_store.retrieve_card(dataset_name_or_card) - - family = card.field("dataset_family").as_(str) - - return family in self._loaders - - -def is_dataset_card(card: AssetCard) -> bool: - """Return ``True`` if ``card`` specifies a dataset.""" - return card.field("dataset_family").exists() - - -def get_dataset_family(card: AssetCard) -> str: - """Return the dataset family name contained in ``card``.""" - return cast(str, card.field("dataset_family").as_(str)) diff --git a/src/fairseq2/datasets/parallel_text.py b/src/fairseq2/datasets/parallel_text.py index 27a7c1835..0a0d9b349 100644 --- a/src/fairseq2/datasets/parallel_text.py +++ b/src/fairseq2/datasets/parallel_text.py @@ -10,11 +10,11 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Any, Literal, NoReturn, cast, final +from typing import Any, Final, cast, final from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( Collater, DataPipeline, @@ -24,9 +24,10 @@ ) from fairseq2.data.text import TextTokenizer, read_text from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.error import DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.error import DatasetError, SplitNotFoundError +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.nn.padding import get_seqs_and_padding_mask @@ -74,12 +75,11 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, seed: int = 2, - **extras: Any, ) -> DataReader[Seq2SeqBatch]: """Create a dataset reader. @@ -132,8 +132,6 @@ def create_reader( The number of batches to prefetch in background. :param seed: The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. """ @abstractmethod @@ -145,23 +143,24 @@ def directions(self, split: str) -> list[Direction]: """Return the directions included ``split``.""" -load_parallel_text_dataset = DelegatingDatasetLoader[ParallelTextDataset]() - - # TODO: FIX, INFER npc = 10 +GENERIC_PARALLEL_TEXT_DATASET_FAMILY: Final = "generic_parallel_text" + + @final class GenericParallelTextDataset(ParallelTextDataset): """Represents a generic file-based parallel text dataset.""" + _name: str _data_dir: Path _splits: dict[str, tuple[list[Direction], list[float]]] def __init__( self, - *, + name: str, data_dir: Path, splits: dict[str, tuple[list[Direction], list[float]]], ) -> None: @@ -172,6 +171,8 @@ def __init__( :param splits: The splits with their directions and their weights. 
""" + self._name = name + for split, (directions, weights) in splits.items(): if len(directions) != len(weights): raise ValueError( @@ -182,18 +183,24 @@ def __init__( self._splits = splits @classmethod - def from_path(cls, path: Path) -> GenericParallelTextDataset: - """Load a :class:`GenericParallelTextDataset` from ``path``.""" + def from_path( + cls, path: Path, name: str | None = None + ) -> GenericParallelTextDataset: + if name is None: + name = f"path:{path.name}" + path = path.expanduser().resolve() if not path.is_dir(): - raise ValueError("`path` must be a directory with a MANIFEST file.") + raise DatasetError( + name, f"The '{path}' path is expected to be a directory with a MANIFEST file." # fmt: skip + ) try: split_names = [d.name for d in path.iterdir() if d.is_dir()] except OSError as ex: - raise RuntimeError( - "The splits cannot be determined. See nested exception for details." + raise DatasetError( + name, f"The splits under the '{path}' directory cannot be determined. See the nested exception for details." # fmt: skip ) from ex splits = {} @@ -205,8 +212,8 @@ def from_path(cls, path: Path) -> GenericParallelTextDataset: with manifest_file.open() as fp: content = list(fp) except OSError as ex: - raise RuntimeError( - f"{manifest_file} cannot be read. See nested exception for details." + raise DatasetError( + name, f"The '{manifest_file}' file cannot be read. See the nested exception for details." # fmt: skip ) from ex # Sort the directions in alphabetical order. @@ -220,38 +227,38 @@ def from_path(cls, path: Path) -> GenericParallelTextDataset: # its weight (e.g. number of examples) in the split. for idx, line in enumerate(content): - def raise_error() -> NoReturn: - raise DatasetError( - f"Each line in {manifest_file} must represent a valid direction and a weight, but line {idx} is '{line}' instead." - ) from None + def error() -> DatasetError: + return DatasetError( + name, f"Each line in the '{manifest_file}' manifest file must represent a valid direction and a weight, but line {idx} is '{line}' instead." # fmt: skip + ) fields = line.rstrip().split("\t") if len(fields) != 2: - raise_error() + raise error() try: direction = cls._parse_direction(fields[0]) except ValueError: - raise_error() + raise error() from None directions.append(direction) try: weight = float(fields[1].strip()) except ValueError: - raise_error() + raise error() from None weights.append(weight) splits[split] = (directions, weights) - return GenericParallelTextDataset(data_dir=path, splits=splits) + return GenericParallelTextDataset(name, data_dir=path, splits=splits) @staticmethod def _parse_direction(s: str) -> Direction: - def raise_error() -> NoReturn: - raise ValueError( + def value_error() -> ValueError: + return ValueError( f"`s` must represent a valid direction, but is '{s}' instead." 
) @@ -262,12 +269,12 @@ def raise_error() -> NoReturn: elif len(parts) == 2: origin, lang_pair = parts else: - raise_error() + raise value_error() parts = lang_pair.split("-") if len(parts) != 2: - raise_error() + raise value_error() source_lang, target_lang = parts @@ -275,7 +282,7 @@ def raise_error() -> NoReturn: target_lang = target_lang.strip() if not source_lang or not target_lang: - raise_error() + raise value_error() return Direction(source_lang, target_lang, origin) @@ -295,17 +302,17 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, seed: int = 2, - **extras: Any, ) -> DataPipelineReader[Seq2SeqBatch]: - try: - directions, weights = self._splits[split] - except KeyError: - self._raise_split_error(split) + directions_weights = self._splits.get(split) + if directions_weights is None: + raise SplitNotFoundError(self._name, split, self._splits.keys()) + + directions, weights = directions_weights # Determine the directions to read. if direction is not None: @@ -409,7 +416,7 @@ def skip(example: dict[str, Any]) -> bool: # Bucket `batch_size` examples. builder.bucket(batching.batch_size, drop_remainder=drop_remainder) else: - raise RuntimeError(f"`{batching}` is not supported.") + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. if batch_shuffle_window != 1: @@ -434,6 +441,7 @@ def skip(example: dict[str, Any]) -> bool: pipeline = builder.map(f).and_return() return DataPipelineReader[Seq2SeqBatch]( + self._name, pipeline, gang, num_accumulate=num_accumulate, @@ -460,12 +468,12 @@ def _read_direction(self, split: str, direction: Direction) -> DataPipelineBuild if not source_file.exists(): raise DatasetError( - f"The source file '{source_file}' is not found under {self._data_dir}." + self._name, f"The source file '{source_file}' is not found under {self._data_dir}." # fmt: skip ) if not target_file.exists(): raise DatasetError( - f"The target file '{target_file}' is not found under {self._data_dir}." + self._name, f"The target file '{target_file}' is not found under {self._data_dir}." # fmt: skip ) source_builder = read_text(source_file, rtrim=True, memory_map=True) @@ -506,35 +514,14 @@ def splits(self) -> set[str]: @override def directions(self, split: str) -> list[Direction]: - try: - directions, _ = self._splits[split] - except KeyError: - self._raise_split_error(split) - - return directions + directions_weights = self._splits.get(split) + if directions_weights is None: + raise SplitNotFoundError(self._name, split, self._splits.keys()) - def _raise_split_error(self, split: str) -> NoReturn: - raise ValueError( - f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits.keys()))}" - ) from None + return directions_weights[0] -@final -class GenericParallelTextDatasetLoader( - AbstractDatasetLoader[GenericParallelTextDataset] -): - @override - def _load(self, path: Path, card: AssetCard) -> GenericParallelTextDataset: - try: - return GenericParallelTextDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." 
- ) from ex - - -load_generic_parallel_text_dataset = GenericParallelTextDatasetLoader() - -load_parallel_text_dataset.register( - "generic_parallel_text", load_generic_parallel_text_dataset -) +def load_parallel_text_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> ParallelTextDataset: + return load_dataset(name_or_card, ParallelTextDataset, force=force) diff --git a/src/fairseq2/datasets/preference.py b/src/fairseq2/datasets/preference.py index e6d758367..315335371 100644 --- a/src/fairseq2/datasets/preference.py +++ b/src/fairseq2/datasets/preference.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path -from typing import Any, cast, final +from typing import Any, Final, cast, final import torch from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -29,8 +29,9 @@ from fairseq2.data.text import TextTokenizer from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching from fairseq2.datasets.data_reader import DataPipelineReader -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.static import load_dataset from fairseq2.datasets.utils import _load_files_and_weights +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.sequence import SequenceBatch from fairseq2.nn.padding import get_seqs_and_padding_mask @@ -64,10 +65,9 @@ def create_reader( num_accumulate: int = 1, num_prefetch: int = 1, mask_source_tokens: bool = True, - src_encode_mode: str = "prompt", - tgt_encode_mode: str = "prompt_response", + source_encode_mode: str = "prompt", + target_encode_mode: str = "prompt_response", seed: int = 2, - **extras: Any, ) -> DataPipelineReader[PreferenceOptimizationBatch]: """Create a dataset reader. @@ -109,39 +109,43 @@ def create_reader( The number of batches to prefetch in background. :param mask_source_tokens: If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens. - :param src_encode_mode: - The mode to encode the prompt - :param tgt_encode_mode: - The mode to encode the target + :param source_encode_mode: + The mode to encode the source text. + :param target_encode_mode: + The mode to encode the target text. :param seed: The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. """ -load_preference_optimization_dataset = DelegatingDatasetLoader[ - PreferenceOptimizationDataset -]() - # TODO: FIX, INFER npc = 10 +GENERIC_PREFERENCE_OPTIMIZATION_DATASET_FAMILY: Final = ( + "generic_preference_optimization" +) + + @final class GenericPreferenceOptimizationDataset(PreferenceOptimizationDataset): """Represents a generic JSONL preference optimization dataset.""" + _name: str _files: Sequence[Path] _weights: Sequence[float] - def __init__(self, files: Sequence[Path], weights: Sequence[float]) -> None: + def __init__( + self, name: str, files: Sequence[Path], weights: Sequence[float] + ) -> None: """ :param files: The instruction files. :param weights: The weight of each file in ``files``. """ + self._name = name + if len(files) != len(weights): raise ValueError( f"The lengths of `files` and `weights` must match, but they are {len(files)} and {len(weights)} instead." 
@@ -150,12 +154,16 @@ def __init__(self, files: Sequence[Path], weights: Sequence[float]) -> None: self._files = files self._weights = weights - @classmethod - def from_path(cls, path: Path) -> GenericPreferenceOptimizationDataset: - """Load a :class:`PreferenceOptimizationDataset` from ``path``.""" - files, weights = _load_files_and_weights(path) + @staticmethod + def from_path( + path: Path, name: str | None = None + ) -> GenericPreferenceOptimizationDataset: + if name is None: + name = f"path:{path.name}" + + files, weights = _load_files_and_weights(name, path) - return GenericPreferenceOptimizationDataset(files, weights) + return GenericPreferenceOptimizationDataset(name, files, weights) @override def create_reader( @@ -174,10 +182,9 @@ def create_reader( num_accumulate: int = 1, num_prefetch: int = 1, mask_source_tokens: bool = True, - src_encode_mode: str = "prompt", - tgt_encode_mode: str = "prompt_response", + source_encode_mode: str = "prompt", + target_encode_mode: str = "prompt_response", seed: int = 2, - **extras: Any, ) -> DataPipelineReader[PreferenceOptimizationBatch]: if len(self._files) == 1: builder = self._read_jsonl(self._files[0], tokenizer) @@ -209,41 +216,41 @@ def create_reader( seed += gang.rank - # Encode prompt and target texts. - prompt_encoder = tokenizer.create_encoder(mode=src_encode_mode) - target_encoder = tokenizer.create_encoder(mode=tgt_encode_mode) + # Encode source and target texts. + source_encoder = tokenizer.create_encoder(mode=source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=target_encode_mode) - builder.map(prompt_encoder, selector="src", num_parallel_calls=npc) + builder.map(source_encoder, selector="src", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt_chosen", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt_rejected", num_parallel_calls=npc) def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: id_ = example.get("id", None) - prompt_indices = example["src"] + source_indices = example["src"] target_indices_chosen = example["tgt_chosen"] target_indices_rejected = example["tgt_rejected"] - indices_chosen = torch.cat([prompt_indices, target_indices_chosen]) - indices_rejected = torch.cat([prompt_indices, target_indices_rejected]) + indices_chosen = torch.cat([source_indices, target_indices_chosen]) + indices_rejected = torch.cat([source_indices, target_indices_rejected]) if mask_source_tokens: - prompt_len = len(prompt_indices) - target_mask_chosen = torch.arange(len(indices_chosen)) >= prompt_len - target_mask_rejected = torch.arange(len(indices_rejected)) >= prompt_len + source_len = len(source_indices) + target_mask_chosen = torch.arange(len(indices_chosen)) >= source_len + target_mask_rejected = torch.arange(len(indices_rejected)) >= source_len else: target_mask_chosen = torch.full([len(indices_chosen)], True) target_mask_rejected = torch.full([len(indices_rejected)], True) total_tokens = ( - 2 * len(prompt_indices) + 2 * len(source_indices) + len(target_indices_chosen) + len(target_indices_rejected) ) return { "id": id_, - "indices_prompt": prompt_indices, + "indices_prompt": source_indices, "indices_chosen": indices_chosen, "indices_rejected": indices_rejected, "target_mask_chosen": target_mask_chosen, @@ -278,7 +285,7 @@ def skip(example: dict[str, Any]) -> bool: # Bucket `batch_size` examples. 
builder.bucket(batching.batch_size, drop_remainder=drop_remainder) else: - raise RuntimeError(f"`{batching}` is not supported.") + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. if batch_shuffle_window != 1: @@ -337,6 +344,7 @@ def to_batch(example: dict[str, Any]) -> PreferenceOptimizationBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[PreferenceOptimizationBatch]( + self._name, pipeline, gang, num_accumulate=num_accumulate, @@ -355,26 +363,7 @@ def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuild return read_sequence(lines).map(json.loads, num_parallel_calls=npc) -@final -class GenericPreferenceOptimizationDatasetLoader( - AbstractDatasetLoader[GenericPreferenceOptimizationDataset] -): - @override - def _load( - self, path: Path, card: AssetCard - ) -> GenericPreferenceOptimizationDataset: - try: - return GenericPreferenceOptimizationDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_preference_optimization_dataset = ( - GenericPreferenceOptimizationDatasetLoader() -) - -load_preference_optimization_dataset.register( - "generic_preference_optimization", load_generic_preference_optimization_dataset -) +def load_preference_optimization_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> PreferenceOptimizationDataset: + return load_dataset(name_or_card, PreferenceOptimizationDataset, force=force) diff --git a/src/fairseq2/datasets/speech.py b/src/fairseq2/datasets/speech.py index 23d33b36f..0e4450a12 100644 --- a/src/fairseq2/datasets/speech.py +++ b/src/fairseq2/datasets/speech.py @@ -8,15 +8,16 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Literal, final +from typing import Final, final import torch from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.datasets.batching import Batching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.sequence import SequenceBatch from fairseq2.typing import DataType @@ -40,12 +41,11 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, seed: int = 2, - **extras: Any, ) -> DataReader[SequenceBatch]: """Create a dataset reader. @@ -96,8 +96,6 @@ def create_reader( The number of batches to prefetch in background. :param seed: The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. 
""" @abstractmethod @@ -105,19 +103,15 @@ def splits(self) -> set[str]: """Return the set of splits.""" -load_speech_dataset = DelegatingDatasetLoader[SpeechDataset]() +GENERIC_SPEECH_DATASET_FAMILY: Final = "generic_speech" @final class GenericSpeechDataset(SpeechDataset): """Represents a generic manifest-based Speech dataset.""" - def __init__(self) -> None: - pass - - @classmethod - def from_path(cls, path: Path) -> GenericSpeechDataset: - """Load a :class:`GenericSpeechDataset` from ``path``.""" + @staticmethod + def from_path(path: Path, name: str | None = None) -> GenericSpeechDataset: return GenericSpeechDataset() @override @@ -135,38 +129,26 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, seed: int = 2, cached_fd_count: int = 1000, - **extras: Any, ) -> DataPipelineReader[SequenceBatch]: """ :param cached_fd_count: The maximum number of file descriptors to keep open while reading audio files. """ - raise RuntimeError("not supported yet.") + raise NotSupportedError("not supported yet.") @override def splits(self) -> set[str]: return set() -@final -class GenericSpeechDatasetLoader(AbstractDatasetLoader[GenericSpeechDataset]): - @override - def _load(self, path: Path, card: AssetCard) -> GenericSpeechDataset: - try: - return GenericSpeechDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_speech_dataset = GenericSpeechDatasetLoader() - -load_speech_dataset.register("generic_speech", load_generic_speech_dataset) +def load_speech_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> SpeechDataset: + return load_dataset(name_or_card, SpeechDataset, force=force) diff --git a/src/fairseq2/datasets/static.py b/src/fairseq2/datasets/static.py new file mode 100644 index 000000000..392f03d44 --- /dev/null +++ b/src/fairseq2/datasets/static.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import TypeVar + +from fairseq2.assets import AssetCard +from fairseq2.context import get_runtime_context +from fairseq2.datasets.handler import ( + DatasetHandler, + DatasetNotFoundError, + get_dataset_family, +) +from fairseq2.error import ContractError + +DatasetT = TypeVar("DatasetT") + + +def load_dataset( + name_or_card: str | AssetCard, kls: type[DatasetT], *, force: bool = False +) -> DatasetT: + context = get_runtime_context() + + if isinstance(name_or_card, AssetCard): + card = name_or_card + else: + card = context.asset_store.retrieve_card(name_or_card) + + family = get_dataset_family(card) + + registry = context.get_registry(DatasetHandler) + + try: + handler = registry.get(family) + except LookupError: + raise DatasetNotFoundError(card.name) from None + + if not issubclass(handler.kls, kls): + raise TypeError( + f"The dataset is expected to be of type `{kls}`, but is of type `{type(handler.kls)}` instead." 
+        )
+
+    dataset = handler.load(card, force=force)
+
+    if not isinstance(dataset, kls):
+        raise ContractError(
+            f"The dataset is expected to be of type `{kls}`, but is of type `{type(dataset)}` instead."
+        )
+
+    return dataset
diff --git a/src/fairseq2/datasets/text.py b/src/fairseq2/datasets/text.py
index 5fab2b244..bcd1dbdc5 100644
--- a/src/fairseq2/datasets/text.py
+++ b/src/fairseq2/datasets/text.py
@@ -10,11 +10,11 @@
 from collections.abc import Sequence
 from functools import partial
 from pathlib import Path
-from typing import Any, Literal, cast, final
+from typing import Any, Final, cast, final

 from typing_extensions import override

-from fairseq2.assets import AssetCard, AssetError
+from fairseq2.assets import AssetCard
 from fairseq2.data import (
     Collater,
     DataPipeline,
@@ -24,8 +24,10 @@
 )
 from fairseq2.data.text import TextTokenEncoder, read_text
 from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching
-from fairseq2.datasets.data_reader import DataPipelineReader, DataReader
-from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader
+from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode
+from fairseq2.datasets.error import DatasetError
+from fairseq2.datasets.static import load_dataset
+from fairseq2.error import NotSupportedError
 from fairseq2.gang import Gang
 from fairseq2.models.sequence import SequenceBatch
 from fairseq2.nn.padding import get_seqs_and_padding_mask
@@ -49,12 +51,11 @@ def create_reader(
         batch_shuffle_window: int = 1,
         drop_remainder: bool = False,
         sync_batches: bool = True,
-        sync_mode: Literal["until_first", "until_last"] = "until_first",
+        sync_mode: SyncMode = "until_first",
         max_num_batches: int | None = None,
         num_accumulate: int = 1,
         num_prefetch: int = 1,
         seed: int = 2,
-        **extras: Any,
     ) -> DataReader[SequenceBatch]:
         """Create a dataset reader.

@@ -103,34 +104,36 @@ def create_reader(
             The number of batches to prefetch in background.
         :param seed:
             The seed to initialize the random number generators used internally.
-        :param extras:
-            The extra parameters specific to the dataset implementation.
         """


-load_text_dataset = DelegatingDatasetLoader[TextDataset]()
-
-
 # TODO: FIX, INFER
 npc = 10


+GENERIC_TEXT_DATASET_FAMILY: Final = "generic_text"
+
+
 @final
 class GenericTextDataset(TextDataset):
     """Represents a generic file-based text dataset."""

+    _name: str
     _files: Sequence[Path]

-    def __init__(self, files: Sequence[Path]) -> None:
+    def __init__(self, name: str, files: Sequence[Path]) -> None:
         """
         :param data_dir:
             The list of text files that represent the dataset.
         """
+        self._name = name
+
         self._files = files

     @staticmethod
-    def from_path(path: Path) -> GenericTextDataset:
-        """Load a :class:`GenericTextDataset` from ``path``."""
+    def from_path(path: Path, name: str | None = None) -> GenericTextDataset:
+        if name is None:
+            name = f"path:{path.name}"
+
         path = path.expanduser().resolve()

         if not path.is_dir():
@@ -139,13 +142,13 @@ def from_path(path: Path) -> GenericTextDataset:
         try:
             files = [f for f in path.glob("**/*.txt") if not f.is_dir()]
         except OSError as ex:
-            raise RuntimeError(
-                f"The text files under {path} cannot be retrieved. See nested exception for details."
+            raise DatasetError(
+                name, f"The text files under the '{path}' directory cannot be retrieved. See the nested exception for details."
# fmt: skip ) from ex files.sort() - return GenericTextDataset(files) + return GenericTextDataset(name, files) @override def create_reader( @@ -161,12 +164,11 @@ def create_reader( batch_shuffle_window: int = 1, drop_remainder: bool = False, sync_batches: bool = True, - sync_mode: Literal["until_first", "until_last"] = "until_first", + sync_mode: SyncMode = "until_first", max_num_batches: int | None = None, num_accumulate: int = 1, num_prefetch: int = 1, seed: int = 2, - **extras: Any, ) -> DataReader[SequenceBatch]: if len(self._files) == 1: builder = read_text(self._files[0], key="text", rtrim=True) @@ -224,7 +226,7 @@ def skip(example: dict[str, Any]) -> bool: # Bucket `batch_size` examples. builder.bucket(batching.batch_size, drop_remainder=drop_remainder) else: - raise RuntimeError(f"`{batching}` is not supported.") + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. if batch_shuffle_window != 1: @@ -249,6 +251,7 @@ def skip(example: dict[str, Any]) -> bool: pipeline = builder.map(f).and_return() return DataPipelineReader[SequenceBatch]( + self._name, pipeline, gang, num_accumulate=num_accumulate, @@ -266,18 +269,7 @@ def _to_batch(example: dict[str, Any], device: Device) -> SequenceBatch: return SequenceBatch(seqs, padding_mask, example=example) -@final -class GenericTextDatasetLoader(AbstractDatasetLoader[GenericTextDataset]): - @override - def _load(self, path: Path, card: AssetCard) -> GenericTextDataset: - try: - return GenericTextDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_text_dataset = GenericTextDatasetLoader() - -load_text_dataset.register("generic_text", load_generic_text_dataset) +def load_text_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> TextDataset: + return load_dataset(name_or_card, TextDataset, force=force) diff --git a/src/fairseq2/datasets/utils.py b/src/fairseq2/datasets/utils.py index a40bf0ca3..8f2a0b526 100644 --- a/src/fairseq2/datasets/utils.py +++ b/src/fairseq2/datasets/utils.py @@ -7,16 +7,15 @@ from __future__ import annotations from pathlib import Path -from typing import NoReturn import torch from fairseq2.datasets.error import DatasetError from fairseq2.gang import Gang, all_sum -from fairseq2.logging import LogWriter +from fairseq2.logging import log -def _min_num_batches(num_batches: int, gang: Gang, log: LogWriter) -> int: +def _min_num_batches(num_batches: int, gang: Gang) -> int: all_num_batches = torch.zeros((gang.size,), device=gang.device, dtype=torch.int64) inp = torch.tensor(num_batches, device=gang.device) @@ -45,7 +44,7 @@ def _sum_num_batches(num_batches: int, gang: Gang) -> int: return int(total_num_batches) -def _load_files_and_weights(path: Path) -> tuple[list[Path], list[float]]: +def _load_files_and_weights(name: str, path: Path) -> tuple[list[Path], list[float]]: path = path.expanduser().resolve() if not path.is_dir(): @@ -59,8 +58,8 @@ def _load_files_and_weights(path: Path) -> tuple[list[Path], list[float]]: except FileNotFoundError: content = None except OSError as ex: - raise RuntimeError( - f"{manifest_file} cannot be read. See nested exception for details." + raise DatasetError( + name, f"The '{manifest_file}' manifest file cannot be read. See the nested exception for details." 
# fmt: skip ) from ex # If the directory does not contain a MANIFEST file, treat all JSONL @@ -69,8 +68,8 @@ def _load_files_and_weights(path: Path) -> tuple[list[Path], list[float]]: try: files = list(path.glob("**/*.jsonl")) except OSError as ex: - raise RuntimeError( - f"The JSONL files under {path} cannot be retrieved. See nested exception for details." + raise DatasetError( + name, f"The JSONL files under the '{path}' directory cannot be retrieved. See the nested exception for details." # fmt: skip ) from ex weights = [1.0 for _ in range(len(files))] @@ -88,28 +87,28 @@ def _load_files_and_weights(path: Path) -> tuple[list[Path], list[float]]: # and its weight (e.g. number of examples). for idx, line in enumerate(content): - def raise_error() -> NoReturn: - raise DatasetError( - f"Each line in {manifest_file} must represent a path to a JSONL file and a weight, but line {idx} is '{line}' instead." + def error() -> DatasetError: + return DatasetError( + name, f"Each line in the '{manifest_file}' manifest file must represent a path to a JSONL file and a weight, but line {idx} is '{line}' instead." # fmt: skip ) fields = line.rstrip().split("\t") if len(fields) != 2: - raise_error() + raise error() file_path = fields[0].strip() if not file_path: - raise_error() + raise error() try: file = path.joinpath(file_path) except ValueError: - raise_error() + raise error() from None if not file.exists(): raise DatasetError( - f"The file '{file}' referred at line {idx} in {manifest_file} does not exist." + name, f"The '{file}' file referred at line {idx} in the '{manifest_file}' manifest file does not exist." # fmt: skip ) files.append(file) @@ -117,7 +116,7 @@ def raise_error() -> NoReturn: try: weight = float(fields[1].strip()) except ValueError: - raise_error() + raise error() from None weights.append(weight) diff --git a/src/fairseq2/recipes/hg/dataset.py b/src/fairseq2/recipes/hg/dataset.py index f4e70c8ac..933334262 100644 --- a/src/fairseq2/recipes/hg/dataset.py +++ b/src/fairseq2/recipes/hg/dataset.py @@ -74,9 +74,6 @@ def create_hf_reader( training. :param num_prefetch: The number of batches to prefetch in background. - :param extras: - The extra parameters specific to the dataset - implementation. 
""" if not has_datasets: raise ModuleNotFoundError( @@ -150,6 +147,7 @@ def skip(example: Example) -> bool: pipeline = builder.map(converter).and_return() return DataPipelineReader[BatchT]( + "hg", pipeline, gang, num_accumulate=num_accumulate, diff --git a/src/fairseq2/recipes/lm/eval_nll.py b/src/fairseq2/recipes/lm/eval_nll.py index 7963e7c1e..6e26f2e80 100644 --- a/src/fairseq2/recipes/lm/eval_nll.py +++ b/src/fairseq2/recipes/lm/eval_nll.py @@ -18,7 +18,7 @@ from fairseq2.datasets.batching import LengthBatching from fairseq2.datasets.instruction import ( GenericInstructionDataset, - load_generic_instruction_dataset, + load_instruction_dataset, ) from fairseq2.logging import get_log_writer from fairseq2.models import load_model @@ -143,7 +143,7 @@ def load_nll_evaluator( if dataset_card is not None: log.info("Loading {} preference optimization dataset.", dataset_card.name) - dataset = load_generic_instruction_dataset(dataset_card) + dataset = load_instruction_dataset(dataset_card) log.info("Dataset loaded.") else: diff --git a/src/fairseq2/recipes/lm/instruction_finetune.py b/src/fairseq2/recipes/lm/instruction_finetune.py index 2f9ec0290..699ee5485 100644 --- a/src/fairseq2/recipes/lm/instruction_finetune.py +++ b/src/fairseq2/recipes/lm/instruction_finetune.py @@ -454,8 +454,8 @@ def load_instruction_finetuner( batch_shuffle_window=config.batch_shuffle_window, num_accumulate=config.gradient_accumulation, num_prefetch=config.num_prefetch, - src_encode_mode=config.src_encode_mode, - tgt_encode_mode=config.tgt_encode_mode, + source_encode_mode=config.src_encode_mode, + target_encode_mode=config.tgt_encode_mode, seed=seed, ) except ValueError as ex: @@ -506,8 +506,8 @@ def load_instruction_finetuner( sync_mode="until_last", num_accumulate=config.gradient_accumulation, num_prefetch=config.num_prefetch, - src_encode_mode=config.src_encode_mode, - tgt_encode_mode=config.tgt_encode_mode, + source_encode_mode=config.src_encode_mode, + target_encode_mode=config.tgt_encode_mode, seed=seed, ) except ValueError as ex: diff --git a/src/fairseq2/recipes/lm/preference_finetune/recipe.py b/src/fairseq2/recipes/lm/preference_finetune/recipe.py index 6eedc2c40..e5d8e6c2a 100644 --- a/src/fairseq2/recipes/lm/preference_finetune/recipe.py +++ b/src/fairseq2/recipes/lm/preference_finetune/recipe.py @@ -406,14 +406,13 @@ def load_preference_finetuner( dp_gang, max_seq_len=config.max_seq_len, batching=batching, - max_num_tokens=config.max_num_tokens, example_shuffle_window=config.example_shuffle_window, batch_shuffle_window=config.batch_shuffle_window, num_accumulate=config.gradient_accumulation, num_prefetch=config.num_prefetch, mask_source_tokens=config.mask_source_tokens, - src_encode_mode=config.src_encode_mode, - tgt_encode_mode=config.tgt_encode_mode, + source_encode_mode=config.src_encode_mode, + target_encode_mode=config.tgt_encode_mode, seed=config.seed, ) except ValueError as ex: diff --git a/src/fairseq2/recipes/lm/text_generate.py b/src/fairseq2/recipes/lm/text_generate.py index ca99890cf..07752cb8a 100644 --- a/src/fairseq2/recipes/lm/text_generate.py +++ b/src/fairseq2/recipes/lm/text_generate.py @@ -288,7 +288,6 @@ def load_text_generator( batching=StaticBatching(config.batch_size), sync_mode="until_last", num_prefetch=config.num_prefetch, - seed=seed, ) except ValueError as ex: raise ValueError( diff --git a/src/fairseq2/setup/__init__.py b/src/fairseq2/setup/__init__.py index 28fd44ced..132b949df 100644 --- a/src/fairseq2/setup/__init__.py +++ b/src/fairseq2/setup/__init__.py @@ -9,6 +9,7 
@@
 from fairseq2.setup.assets import (
     register_package_metadata_provider as register_package_metadata_provider,
 )
+from fairseq2.setup.datasets import register_dataset as register_dataset
 from fairseq2.setup.root import setup_fairseq2 as setup_fairseq2
 from fairseq2.setup.root import setup_runtime_context as setup_runtime_context
 from fairseq2.setup.text_tokenizers import (
diff --git a/src/fairseq2/setup/datasets.py b/src/fairseq2/setup/datasets.py
new file mode 100644
index 000000000..ea4bc54ad
--- /dev/null
+++ b/src/fairseq2/setup/datasets.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from fairseq2.context import RuntimeContext
+from fairseq2.datasets import DatasetHandler, DatasetLoader, StandardDatasetHandler
+from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, GenericAsrDataset
+from fairseq2.datasets.instruction import (
+    GENERIC_INSTRUCTION_DATASET_FAMILY,
+    GenericInstructionDataset,
+)
+from fairseq2.datasets.parallel_text import (
+    GENERIC_PARALLEL_TEXT_DATASET_FAMILY,
+    GenericParallelTextDataset,
+)
+from fairseq2.datasets.preference import (
+    GENERIC_PREFERENCE_OPTIMIZATION_DATASET_FAMILY,
+    GenericPreferenceOptimizationDataset,
+)
+from fairseq2.datasets.speech import GENERIC_SPEECH_DATASET_FAMILY, GenericSpeechDataset
+from fairseq2.datasets.text import GENERIC_TEXT_DATASET_FAMILY, GenericTextDataset
+
+
+def _register_datasets(context: RuntimeContext) -> None:
+    register_dataset(
+        context,
+        GENERIC_ASR_DATASET_FAMILY,
+        kls=GenericAsrDataset,
+        loader=GenericAsrDataset.from_path,
+    )
+
+    register_dataset(
+        context,
+        GENERIC_INSTRUCTION_DATASET_FAMILY,
+        kls=GenericInstructionDataset,
+        loader=GenericInstructionDataset.from_path,
+    )
+
+    register_dataset(
+        context,
+        GENERIC_PARALLEL_TEXT_DATASET_FAMILY,
+        kls=GenericParallelTextDataset,
+        loader=GenericParallelTextDataset.from_path,
+    )
+
+    register_dataset(
+        context,
+        GENERIC_PREFERENCE_OPTIMIZATION_DATASET_FAMILY,
+        kls=GenericPreferenceOptimizationDataset,
+        loader=GenericPreferenceOptimizationDataset.from_path,
+    )
+
+    register_dataset(
+        context,
+        GENERIC_SPEECH_DATASET_FAMILY,
+        kls=GenericSpeechDataset,
+        loader=GenericSpeechDataset.from_path,
+    )
+
+    register_dataset(
+        context,
+        GENERIC_TEXT_DATASET_FAMILY,
+        kls=GenericTextDataset,
+        loader=GenericTextDataset.from_path,
+    )
+
+
+def register_dataset(
+    context: RuntimeContext, family: str, *, kls: type, loader: DatasetLoader
+) -> None:
+    handler = StandardDatasetHandler(kls, loader, context.asset_download_manager)
+
+    registry = context.get_registry(DatasetHandler)
+
+    registry.register(family, handler)
diff --git a/src/fairseq2/setup/root.py b/src/fairseq2/setup/root.py
index 2806c8a7c..ee60f321b 100644
--- a/src/fairseq2/setup/root.py
+++ b/src/fairseq2/setup/root.py
@@ -13,6 +13,7 @@
 from fairseq2.setup.chatbots import _register_chatbots
 from fairseq2.setup.clusters import _register_clusters
 from fairseq2.setup.config import _register_config_sections
+from fairseq2.setup.datasets import _register_datasets
 from fairseq2.setup.generation import (
     _register_beam_search_algorithms,
     _register_samplers,
@@ -63,6 +64,7 @@ def setup_runtime_context() -> RuntimeContext:
     _register_chatbots(context)
     _register_clusters(context)
     _register_config_sections(context)
+    _register_datasets(context)
     _register_lr_schedulers(context)
_register_optimizers(context) _register_samplers(context) diff --git a/src/fairseq2/setup/text_tokenizers.py b/src/fairseq2/setup/text_tokenizers.py index 3b9a1b3a6..f4ccadcc5 100644 --- a/src/fairseq2/setup/text_tokenizers.py +++ b/src/fairseq2/setup/text_tokenizers.py @@ -7,21 +7,31 @@ from __future__ import annotations from fairseq2.context import RuntimeContext -from fairseq2.data.text import ( +from fairseq2.data.text.tokenizers import ( + StandardTextTokenizerHandler, + TextTokenizerHandler, + TextTokenizerLoader, +) +from fairseq2.data.text.tokenizers.char_tokenizer import ( CHAR_TOKENIZER_FAMILY, + load_char_tokenizer, +) +from fairseq2.data.text.tokenizers.llama import ( LLAMA_TOKENIZER_FAMILY, + load_llama_tokenizer, +) +from fairseq2.data.text.tokenizers.mistral import ( MISTRAL_TOKENIZER_FAMILY, + load_mistral_tokenizer, +) +from fairseq2.data.text.tokenizers.nllb import ( NLLB_TOKENIZER_FAMILY, + load_nllb_tokenizer, +) +from fairseq2.data.text.tokenizers.s2t_transformer import ( S2T_TRANSFORMER_TOKENIZER_FAMILY, - StandardTextTokenizerHandler, - TextTokenizerHandler, - TextTokenizerLoader, + load_s2t_transformer_tokenizer, ) -from fairseq2.data.text.tokenizers.char_tokenizer import load_char_tokenizer -from fairseq2.data.text.tokenizers.llama import load_llama_tokenizer -from fairseq2.data.text.tokenizers.mistral import load_mistral_tokenizer -from fairseq2.data.text.tokenizers.nllb import load_nllb_tokenizer -from fairseq2.data.text.tokenizers.s2t_transformer import load_s2t_transformer_tokenizer def _register_text_tokenizers(context: RuntimeContext) -> None: diff --git a/tests/unit/data/text/test_sentencepiece.py b/tests/unit/data/text/test_sentencepiece.py index 76ae57ec7..ac4faafed 100644 --- a/tests/unit/data/text/test_sentencepiece.py +++ b/tests/unit/data/text/test_sentencepiece.py @@ -14,7 +14,7 @@ import pytest import torch -from fairseq2.data.text import ( +from fairseq2.data.text.tokenizers.sentencepiece import ( SentencePieceDecoder, SentencePieceEncoder, SentencePieceModel,
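
Taken together, these changes replace the ``DelegatingDatasetLoader`` registry with dataset handlers that are registered during setup, and the typed wrappers added above (``load_text_dataset``, ``load_speech_dataset``, ``load_parallel_text_dataset``, ``load_preference_optimization_dataset``) all delegate to the generic ``load_dataset`` helper in ``fairseq2.datasets.static``. A minimal usage sketch, assuming ``setup_fairseq2()`` has been called and an asset card named ``my_text_dataset`` with ``dataset_family: generic_text`` exists (the card name is hypothetical):

.. code-block:: python

    from fairseq2.datasets.text import TextDataset, load_text_dataset
    from fairseq2.setup import setup_fairseq2

    setup_fairseq2()  # registers the built-in dataset families

    # Resolve the (hypothetical) "my_text_dataset" asset card and load it
    # through the handler registered for its `dataset_family`.
    dataset = load_text_dataset("my_text_dataset")

    assert isinstance(dataset, TextDataset)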
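The new ``register_dataset`` helper exported from ``fairseq2.setup`` is also the extension point for user-defined families. The sketch below registers a custom family under a made-up name; it reuses ``GenericTextDataset`` purely for illustration, since any class whose loader accepts a path (and an optional name) and returns a dataset instance would be wired up the same way:

.. code-block:: python

    from fairseq2.context import get_runtime_context
    from fairseq2.datasets.text import GenericTextDataset
    from fairseq2.setup import register_dataset, setup_fairseq2

    setup_fairseq2()

    context = get_runtime_context()

    # "my_jsonl_text" is a hypothetical family name; asset cards that set
    # `dataset_family: my_jsonl_text` would now resolve to this handler.
    register_dataset(
        context,
        "my_jsonl_text",
        kls=GenericTextDataset,
        loader=GenericTextDataset.from_path,
    )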
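For the manifest handling in ``_load_files_and_weights``, the expected on-disk layout is unchanged: a directory of JSONL files, optionally accompanied by a ``MANIFEST`` file in which each line holds a file path and a weight separated by a tab. Paths are resolved relative to the dataset directory, and the weight must parse as a float. A hypothetical example, with ``<TAB>`` standing for a literal tab character and invented file names:

.. code-block:: text

    chunk_000.jsonl<TAB>1.0
    chunk_001.jsonl<TAB>0.5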