diff --git a/src/fairseq2/datasets/__init__.py b/src/fairseq2/datasets/__init__.py index bedc5bd3c..386794bab 100644 --- a/src/fairseq2/datasets/__init__.py +++ b/src/fairseq2/datasets/__init__.py @@ -6,9 +6,10 @@ from __future__ import annotations -from fairseq2.datasets.batching import Batching as Batching -from fairseq2.datasets.batching import LengthBatching as LengthBatching -from fairseq2.datasets.batching import StaticBatching as StaticBatching +from fairseq2.datasets.config import Batching as Batching +from fairseq2.datasets.config import DataReadOptions as DataReadOptions +from fairseq2.datasets.config import LengthBatching as LengthBatching +from fairseq2.datasets.config import StaticBatching as StaticBatching from fairseq2.datasets.data_reader import DataPipelineReader as DataPipelineReader from fairseq2.datasets.data_reader import DataReader as DataReader from fairseq2.datasets.data_reader import SyncMode as SyncMode diff --git a/src/fairseq2/datasets/asr.py b/src/fairseq2/datasets/asr.py index be62244bf..1bfc31a46 100644 --- a/src/fairseq2/datasets/asr.py +++ b/src/fairseq2/datasets/asr.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Any, Final, cast, final @@ -28,8 +29,13 @@ ) from fairseq2.data.audio import AudioDecoder from fairseq2.data.text import StrSplitter, TextTokenizer, read_text -from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader from fairseq2.datasets.error import DatasetError, SplitNotFoundError from fairseq2.datasets.static import load_dataset from fairseq2.error import NotSupportedError @@ -48,21 +54,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_audio_len: int, max_audio_len: int, batching: Batching, - *, - dtype: DataType = torch.float32, - min_audio_len: int = 1, - normalize_audio: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, + options: AsrReadOptions | None = None, ) -> DataReader[Seq2SeqBatch]: """Create a dataset reader. @@ -72,49 +67,16 @@ def create_reader( The tokenizer to encode target text. :param gang: The gang over which to shard the dataset. + :param min_audio_len: + The minimum audio length of each example. Examples shorter than this + value will be dropped. :param max_audio_len: - The maximum audio length of each example. Examples longer than - this value will be dropped. + The maximum audio length of each example. Examples longer than this + value will be dropped. :param batching: The batching strategy for returned examples. - :param dtype: - The data type of the decoded audio sequences. - :param min_audio_len: - The minimum audio length of each example. Examples shorter than - this value will be dropped. - :param normalize_audio: - If ``True``, normalizes audio to have zero mean and unit variance. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. 
If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param sync_mode: - If ``until_first``, stops iteration when the first rank reaches end - of data. If ``until_last``, stops iteration when the last rank - reaches end of data; ranks that have already reached their end of - data will return an empty list of batches. - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. + :param options: + The read options. """ @abstractmethod @@ -122,6 +84,15 @@ def splits(self) -> set[str]: """Return the set of splits.""" +@dataclass +class AsrReadOptions(DataReadOptions): + dtype: DataType = torch.float32 + """The data type of the decoded audio sequences.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" + + # TODO: FIX, INFER npc = 10 @@ -174,22 +145,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_audio_len: int, max_audio_len: int, batching: Batching, - *, - dtype: DataType = torch.float32, - min_audio_len: int = 1, - normalize_audio: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - cached_fd_count: int = 1000, + options: AsrReadOptions | None = None, ) -> DataPipelineReader[Seq2SeqBatch]: """ :param cached_fd_count: @@ -199,13 +158,18 @@ def create_reader( if split not in self._splits: raise SplitNotFoundError(self._name, split, self._splits) + if options is None: + options = AsrReadOptions() + + seed = options.seed + audio_dir = self._retrieve_data_directory(split) builder = self._read_manifest(split) # Shuffle examples. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(example_shuffle_window, seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed) seed += 1 @@ -217,8 +181,8 @@ def create_reader( if isinstance(batching, LengthBatching): # Bucket by the audio length. 
bucket_sizes = create_bucket_sizes( - max_seq_len=max_audio_len, min_seq_len=min_audio_len, + max_seq_len=max_audio_len, max_num_elements=batching.max_num_elements, num_seqs_multiple_of=8, ) @@ -229,7 +193,7 @@ def create_reader( min_data_len=min_audio_len, skip_below_min_examples=True, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) elif isinstance(batching, StaticBatching): # Filter out out-of-range audios. @@ -241,23 +205,31 @@ def skip(example: dict[str, object]) -> bool: builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) else: raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed) seed += 1 # Memory map audio files. + cached_fd_count = options.extras.get("cached_fd_count", 1) + if not isinstance(cached_fd_count, int): + raise TypeError( + f"`options.extras['cached_fd_count']` must be of type `int`, but is of type `{type(cached_fd_count)}` instead." + ) + file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count) builder.map(file_mapper, selector="[*].audio") # Decode audio. - audio_decoder = AudioDecoder(dtype=torch.float32 if normalize_audio else dtype) + audio_decoder = AudioDecoder( + dtype=torch.float32 if options.normalize_audio else options.dtype + ) builder.map(audio_decoder, selector="[*].audio.data") @@ -268,9 +240,9 @@ def normalize(waveform: Tensor) -> Tensor: with torch.no_grad(): waveform = layer_norm(waveform, waveform.shape) - return waveform.to(dtype) + return waveform.to(options.dtype) - if normalize_audio: + if options.normalize_audio: builder.map(normalize, selector="[*].audio.data.waveform") # Tokenize target text. @@ -288,11 +260,11 @@ def normalize(waveform: Tensor) -> Tensor: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `Seq2SeqBatch`. def to_batch(example: dict[str, Any]) -> Seq2SeqBatch: @@ -320,10 +292,10 @@ def to_batch(example: dict[str, Any]) -> Seq2SeqBatch: self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, - sync_mode=sync_mode, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) def _retrieve_data_directory(self, split: str) -> Path: diff --git a/src/fairseq2/datasets/batching.py b/src/fairseq2/datasets/batching.py deleted file mode 100644 index e977957c0..000000000 --- a/src/fairseq2/datasets/batching.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from __future__ import annotations - -from dataclasses import dataclass -from typing import TypeAlias - - -@dataclass -class StaticBatching: - """Specifies batching where each batch has the same number of examples.""" - - batch_size: int - """The number of examples in each batch.""" - - -@dataclass -class LengthBatching: - """Specifies batching where each batch has a maximum number of elements.""" - - max_num_elements: int - """The maximum number of elements (e.g. tokens) in each batch.""" - - -Batching: TypeAlias = StaticBatching | LengthBatching diff --git a/src/fairseq2/datasets/config.py b/src/fairseq2/datasets/config.py new file mode 100644 index 000000000..8eda42644 --- /dev/null +++ b/src/fairseq2/datasets/config.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import MutableMapping +from dataclasses import dataclass, field +from typing import TypeAlias + +from fairseq2.datasets.data_reader import SyncMode + + +@dataclass +class StaticBatching: + """Specifies batching where each batch has the same number of examples.""" + + batch_size: int + """The number of examples in each batch.""" + + +@dataclass +class LengthBatching: + """Specifies batching where each batch has a maximum number of elements.""" + + max_num_elements: int + """The maximum number of elements (e.g. tokens) in each batch.""" + + +Batching: TypeAlias = StaticBatching | LengthBatching + + +@dataclass(kw_only=True) +class DataReadOptions: + example_shuffle_window: int = 1 + """ + The size of the sliding window for shuffling examples. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + batch_shuffle_window: int = 1 + """ + The size of the sliding window for shuffling batches. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + drop_remainder: bool = False + """ + If ``True``, drops the last set of batches if they have in total fewer + examples than requested. + """ + + sync_batches: bool = True + """ + If ``True``, ensures that each process in ``gang`` reads the same number of + batches. Typically used when the amount of data to be read can vary per + process (e.g. due to unbalanced sharding or non-static batching) and it is + critical for each process to iterate over the same number of batches (e.g. + during training). + """ + + sync_mode: SyncMode = "until_first" + """ + If ``until_first``, stops iteration on all ranks when one of the ranks + reaches its end of data. If ``until_last``, stops iteration when all ranks + reach their end of data; ranks that have already reached their end of data + will return an empty list of batches. + """ + + max_num_batches: int | None = None + """The maximum number of batches to return.""" + + num_accumulate: int = 1 + """ + The number of batches to accumulate in each iteration. Typically used with + gradient accumulation during training. 
+ """ + + num_prefetch: int = 1 + """The number of batches to prefetch in background.""" + + seed: int = 2 + """The seed to initialize the random number generators used internally.""" + + extras: MutableMapping[str, object] = field(default_factory=dict) + """The reader-specific extra options.""" diff --git a/src/fairseq2/datasets/instruction.py b/src/fairseq2/datasets/instruction.py index 2822dfb05..db33a5120 100644 --- a/src/fairseq2/datasets/instruction.py +++ b/src/fairseq2/datasets/instruction.py @@ -9,6 +9,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Sequence +from dataclasses import dataclass from pathlib import Path from typing import Any, Final, cast, final @@ -26,8 +27,13 @@ read_sequence, ) from fairseq2.data.text import TextTokenizer -from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader from fairseq2.datasets.error import DatasetError, SplitNotFoundError from fairseq2.datasets.static import load_dataset from fairseq2.datasets.utils import _load_files_and_weights @@ -46,21 +52,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - source_encode_mode: str = "prompt", - target_encode_mode: str = "prompt_response", - seed: int = 2, + options: InstructionReadOptions | None = None, ) -> DataReader[SequenceBatch]: """Create a dataset reader. @@ -70,49 +65,16 @@ def create_reader( The tokenizer to encode text. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param sample: - If ``True``, instruction sources (e.g. files) will be sampled in - proportion to their weights. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param sync_mode: - If ``until_first``, stops iteration when the first rank reaches end - of data. 
If ``until_last``, stops iteration when the last rank - reaches end of data; ranks that have already reached their end of - data will return an empty list of batches. - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param source_encode_mode: - The mode to encode the source text. - :param target_encode_mode: - The mode to encode the target text. - :param seed: - The seed to initialize the random number generators used internally. + :param options: + The read options. """ @abstractmethod @@ -121,14 +83,10 @@ def create_prompt_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: StaticBatching, - *, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - num_prefetch: int = 1, - source_encode_mode: str = "prompt", + options: InstructionPromptReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: """Create a dataset reader for evaluation. @@ -143,19 +101,8 @@ def create_prompt_reader( this value will be dropped. :param batching: The batching strategy for returned examples. - :param drop_remainder: - If ``True``, drops the last batch if it has fewer examples than - requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param num_prefetch: - The number of batches to prefetch in background. - :param source_encode_mode: - The mode to encode the source text. + :param options: + The read options. """ @abstractmethod @@ -163,6 +110,27 @@ def splits(self) -> set[str]: """Return the set of splits.""" +@dataclass +class InstructionReadOptions(DataReadOptions): + sample: bool = False + """ + If ``True``, instruction sources (e.g. JSONL files) will be sampled in + proportion to their weights. 
+ """ + + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" + + target_encode_mode: str = "prompt_response" + """The tokenizer mode to encode the target text.""" + + +@dataclass +class InstructionPromptReadOptions(DataReadOptions): + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" + + # TODO: FIX, INFER npc = 10 @@ -230,26 +198,20 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - source_encode_mode: str = "prompt", - target_encode_mode: str = "prompt_response", - seed: int = 2, + options: InstructionReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: files_weights = self._splits.get(split) if files_weights is None: raise SplitNotFoundError(self._name, split, self._splits.keys()) + if options is None: + options = InstructionReadOptions() + + seed = options.seed + files, weights = files_weights if len(files) == 1: @@ -262,7 +224,7 @@ def create_reader( pipelines.append(pipeline) - if sample: + if options.sample: builder = DataPipeline.sample(pipelines, weights=weights, seed=seed) seed += 1 @@ -270,8 +232,8 @@ def create_reader( builder = DataPipeline.concat(pipelines) # Shuffle files. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(shuffle_window=0, seed=seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed=seed) seed += 1 @@ -281,8 +243,8 @@ def create_reader( seed += gang.rank # Encode source and target texts. - source_encoder = tokenizer.create_encoder(mode=source_encode_mode) - target_encoder = tokenizer.create_encoder(mode=target_encode_mode) + source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode) builder.map(source_encoder, selector="src", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt", num_parallel_calls=npc) @@ -303,31 +265,36 @@ def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( - max_seq_len=max_seq_len, max_num_elements=batching.max_num_elements + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + max_num_elements=batching.max_num_elements, ) # Bucket by the sequence length. builder.bucket_by_length( bucket_sizes, selector="indices", + min_data_len=min_seq_len, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) elif isinstance(batching, StaticBatching): # Filter out long examples. def skip(example: dict[str, Any]) -> bool: - return len(example["indices"]) <= max_seq_len + seq_len = len(example["indices"]) + + return seq_len >= min_seq_len and seq_len <= max_seq_len builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) else: raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. 
- if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed=seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed=seed) seed += 1 @@ -341,11 +308,11 @@ def skip(example: dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `SequenceBatch`. def to_batch(example: dict[str, Any]) -> SequenceBatch: @@ -363,10 +330,10 @@ def to_batch(example: dict[str, Any]) -> SequenceBatch: self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, - sync_mode=sync_mode, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) @override @@ -375,20 +342,19 @@ def create_prompt_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: StaticBatching, - *, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - num_prefetch: int = 1, - source_encode_mode: str = "prompt", + options: InstructionPromptReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: try: files, weights = self._splits[split] except KeyError: raise SplitNotFoundError(self._name, split, self._splits.keys()) from None + if options is None: + options = InstructionPromptReadOptions() + if len(files) == 1: builder = self._read_jsonl(files[0], tokenizer) else: @@ -405,7 +371,7 @@ def create_prompt_reader( builder.shard(gang.rank, gang.size, allow_uneven=True) # Encode source texts. - text_encoder = tokenizer.create_encoder(mode=source_encode_mode) + text_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) def encode(example: dict[str, Any]) -> dict[str, Any]: id_ = example.get("id") @@ -420,12 +386,14 @@ def encode(example: dict[str, Any]) -> dict[str, Any]: # Filter out long examples. def skip(example: dict[str, Any]) -> bool: - return len(example["indices"]) <= max_seq_len + seq_len = len(example["indices"]) + + return seq_len >= min_seq_len and seq_len <= max_seq_len builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) # Collate bucketed examples into a batch. collater = Collater(pad_value=tokenizer.vocab_info.pad_idx or 0) @@ -433,7 +401,7 @@ def skip(example: dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `SequenceBatch`. 
def to_batch(example: dict[str, Any]) -> SequenceBatch: @@ -449,9 +417,9 @@ def to_batch(example: dict[str, Any]) -> SequenceBatch: self._name, pipeline, gang, - drop_remainder=drop_remainder, - sync_batches=sync_batches, - sync_mode=sync_mode, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder: diff --git a/src/fairseq2/datasets/parallel_text.py b/src/fairseq2/datasets/parallel_text.py index 0a0d9b349..f53d9e42d 100644 --- a/src/fairseq2/datasets/parallel_text.py +++ b/src/fairseq2/datasets/parallel_text.py @@ -23,8 +23,13 @@ create_bucket_sizes, ) from fairseq2.data.text import TextTokenizer, read_text -from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader from fairseq2.datasets.error import DatasetError, SplitNotFoundError from fairseq2.datasets.static import load_dataset from fairseq2.error import NotSupportedError @@ -34,28 +39,6 @@ from fairseq2.typing import Device -@dataclass(unsafe_hash=True) # Due to FSDP, we cannot freeze. -class Direction: - """Represents the language direction of a parallel corpus.""" - - source_lang: str - """The source language code.""" - - target_lang: str - """The target language code.""" - - origin: str | None = None - """The origin of data. Typically used to indicate mined or synthetic data.""" - - def __repr__(self) -> str: - s = f"{self.source_lang}-{self.target_lang}" - - if self.origin: - s = f"{self.origin}/{s}" - - return s - - class ParallelTextDataset(ABC): """Represents a parallel text dataset.""" @@ -65,21 +48,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - direction: Direction | None = None, - min_seq_len: int = 1, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, + options: ParallelTextReadOptions | None = None, ) -> DataReader[Seq2SeqBatch]: """Create a dataset reader. @@ -89,49 +61,16 @@ def create_reader( The tokenizer to encode text. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param direction: - The direction to read. If ``None``, all directions will be read. - :param min_seq_len: - The minimum sequence length of each example. Examples shorter than - this value will be dropped. - :param sample: - If ``True``, corpora will be sampled in proportion to their weights. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. 
If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param sync_mode: - If ``until_first``, stops iteration when the first rank reaches end - of data. If ``until_last``, stops iteration when the last rank - reaches end of data; ranks that have already reached their end of - data will return an empty list of batches. - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. + :param options: + The read options. """ @abstractmethod @@ -143,6 +82,37 @@ def directions(self, split: str) -> list[Direction]: """Return the directions included ``split``.""" +@dataclass +class ParallelTextReadOptions(DataReadOptions): + direction: Direction | None = None + """The direction to read. If ``None``, all directions will be read.""" + + sample: bool = False + """If ``True``, corpora will be sampled in proportion to their weights.""" + + +@dataclass(unsafe_hash=True) # Due to FSDP, we cannot freeze. +class Direction: + """Represents the language direction of a parallel corpus.""" + + source_lang: str + """The source language code.""" + + target_lang: str + """The target language code.""" + + origin: str | None = None + """The origin of data. Typically used to indicate mined or synthetic data.""" + + def __repr__(self) -> str: + s = f"{self.source_lang}-{self.target_lang}" + + if self.origin: + s = f"{self.origin}/{s}" + + return s + + # TODO: FIX, INFER npc = 10 @@ -292,29 +262,25 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - direction: Direction | None = None, - min_seq_len: int = 1, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, + options: ParallelTextReadOptions | None = None, ) -> DataPipelineReader[Seq2SeqBatch]: directions_weights = self._splits.get(split) if directions_weights is None: raise SplitNotFoundError(self._name, split, self._splits.keys()) + if options is None: + options = ParallelTextReadOptions() + + seed = options.seed + directions, weights = directions_weights # Determine the directions to read. + direction = options.direction + if direction is not None: if direction not in directions: raise ValueError( @@ -353,7 +319,7 @@ def create_reader( pipelines.append(pipeline) - if sample: + if options.sample: builder = DataPipeline.sample(pipelines, weights=weights, seed=seed) seed += 1 @@ -361,8 +327,8 @@ def create_reader( builder = DataPipeline.concat(pipelines) # Shuffle examples. Must be consistent across all processes. 
- if example_shuffle_window != 1: - builder.shuffle(example_shuffle_window, seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed) seed += 1 @@ -386,8 +352,8 @@ def encode(example: dict[str, Any]) -> dict[str, Any]: if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( - max_seq_len=max_seq_len, min_seq_len=min_seq_len, + max_seq_len=max_seq_len, max_num_elements=batching.max_num_elements, ) @@ -399,7 +365,7 @@ def encode(example: dict[str, Any]) -> dict[str, Any]: min_data_len=min_seq_len, skip_below_min_examples=True, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) elif isinstance(batching, StaticBatching): # Filter out out-of-range examples. @@ -414,13 +380,13 @@ def skip(example: dict[str, Any]) -> bool: builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) else: raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed) seed += 1 @@ -430,11 +396,11 @@ def skip(example: dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) f = partial(self._to_batch, device=gang.device) @@ -444,10 +410,10 @@ def skip(example: dict[str, Any]) -> bool: self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, - sync_mode=sync_mode, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) def _read_direction(self, split: str, direction: Direction) -> DataPipelineBuilder: diff --git a/src/fairseq2/datasets/preference.py b/src/fairseq2/datasets/preference.py index 315335371..6cb1bcdbd 100644 --- a/src/fairseq2/datasets/preference.py +++ b/src/fairseq2/datasets/preference.py @@ -27,7 +27,12 @@ read_sequence, ) from fairseq2.data.text import TextTokenizer -from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) from fairseq2.datasets.data_reader import DataPipelineReader from fairseq2.datasets.static import load_dataset from fairseq2.datasets.utils import _load_files_and_weights @@ -53,21 +58,10 @@ def create_reader( self, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - mask_source_tokens: bool = True, - source_encode_mode: str = "prompt", - target_encode_mode: str = "prompt_response", - seed: int = 2, + options: PreferenceReadOptions | None = None, ) -> DataPipelineReader[PreferenceOptimizationBatch]: """Create a dataset reader. 
@@ -75,49 +69,40 @@ def create_reader( The tokenizer to encode text. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param sample: - If ``True``, instruction sources (e.g. files) will be sampled in - proportion to their weights. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param mask_source_tokens: - If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens. - :param source_encode_mode: - The mode to encode the source text. - :param target_encode_mode: - The mode to encode the target text. - :param seed: - The seed to initialize the random number generators used internally. + :param options: + The read options. """ +@dataclass +class PreferenceReadOptions(DataReadOptions): + sample: bool = False + """ + If ``True``, instruction sources (e.g. JSONL files) will be sampled in + proportion to their weights. + """ + + mask_source_tokens: bool = True + """ + If ``False``, calculates loss on the source tokens (prompt) as well as the + target tokens. 
+ """ + + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" + + target_encode_mode: str = "prompt_response" + """The tokenizer mode to encode the target text.""" + + # TODO: FIX, INFER npc = 10 @@ -170,22 +155,16 @@ def create_reader( self, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - mask_source_tokens: bool = True, - source_encode_mode: str = "prompt", - target_encode_mode: str = "prompt_response", - seed: int = 2, + options: PreferenceReadOptions | None = None, ) -> DataPipelineReader[PreferenceOptimizationBatch]: + if options is None: + options = PreferenceReadOptions() + + seed = options.seed + if len(self._files) == 1: builder = self._read_jsonl(self._files[0], tokenizer) else: @@ -196,7 +175,7 @@ def create_reader( pipelines.append(pipeline) - if sample: + if options.sample: builder = DataPipeline.sample( pipelines, weights=self._weights, seed=seed ) @@ -206,8 +185,8 @@ def create_reader( builder = DataPipeline.concat(pipelines) # Shuffle files. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(shuffle_window=0, seed=seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed=seed) seed += 1 @@ -217,8 +196,8 @@ def create_reader( seed += gang.rank # Encode source and target texts. - source_encoder = tokenizer.create_encoder(mode=source_encode_mode) - target_encoder = tokenizer.create_encoder(mode=target_encode_mode) + source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode) builder.map(source_encoder, selector="src", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt_chosen", num_parallel_calls=npc) @@ -234,7 +213,7 @@ def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: indices_chosen = torch.cat([source_indices, target_indices_chosen]) indices_rejected = torch.cat([source_indices, target_indices_rejected]) - if mask_source_tokens: + if options.mask_source_tokens: source_len = len(source_indices) target_mask_chosen = torch.arange(len(indices_chosen)) >= source_len target_mask_rejected = torch.arange(len(indices_rejected)) >= source_len @@ -262,15 +241,18 @@ def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( - max_seq_len=max_seq_len, max_num_elements=batching.max_num_elements + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + max_num_elements=batching.max_num_elements, ) # Bucket by the sequence length. builder.bucket_by_length( bucket_sizes, selector="total_tokens", + min_data_len=min_seq_len, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) elif isinstance(batching, StaticBatching): # Filter out long examples. 
@@ -278,18 +260,21 @@ def skip(example: dict[str, Any]) -> bool: chosen_len = len(example["indices_chosen"]) rejected_len = len(example["indices_rejected"]) - return chosen_len <= max_seq_len and rejected_len <= max_seq_len + if chosen_len > max_seq_len or rejected_len > max_seq_len: + return False + + return chosen_len >= min_seq_len and rejected_len >= min_seq_len builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) else: raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed=seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed=seed) seed += 1 @@ -304,11 +289,11 @@ def skip(example: dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `PreferenceOptimizationBatch`. def to_batch(example: dict[str, Any]) -> PreferenceOptimizationBatch: @@ -347,9 +332,9 @@ def to_batch(example: dict[str, Any]) -> PreferenceOptimizationBatch: self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, ) def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder: diff --git a/src/fairseq2/datasets/speech.py b/src/fairseq2/datasets/speech.py index 0e4450a12..66e891b09 100644 --- a/src/fairseq2/datasets/speech.py +++ b/src/fairseq2/datasets/speech.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Final, final @@ -14,8 +15,8 @@ from typing_extensions import override from fairseq2.assets import AssetCard -from fairseq2.datasets.batching import Batching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.config import Batching, DataReadOptions +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader from fairseq2.datasets.static import load_dataset from fairseq2.error import NotSupportedError from fairseq2.gang import Gang @@ -31,21 +32,10 @@ def create_reader( self, split: str, gang: Gang, + min_audio_len: int, max_audio_len: int, batching: Batching, - *, - dtype: DataType = torch.float32, - min_audio_len: int = 1, - normalize_audio: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, + options: SpeechReadOptions | None = None, ) -> DataReader[SequenceBatch]: """Create a dataset reader. @@ -53,49 +43,16 @@ def create_reader( The split to read. :param gang: The gang over which to shard the dataset. + :param min_audio_len: + The minimum audio length of each example. Examples shorter than this + value will be dropped. 
:param max_audio_len: - The maximum audio length of each example. Examples longer than - this value will be dropped. + The maximum audio length of each example. Examples longer than this + value will be dropped. :param batching: The batching strategy for returned examples. - :param dtype: - The data type of the decoded audio sequences. - :param min_audio_len: - The minimum audio length of each example. Examples shorter than - this value will be dropped. - :param normalize_audio: - If ``True``, normalizes audio to have zero mean and unit variance. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param sync_mode: - If ``until_first``, stops iteration when the first rank reaches end - of data. If ``until_last``, stops iteration when the last rank - reaches end of data; ranks that have already reached their end of - data will return an empty list of batches. - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. + :param options: + The read options. """ @abstractmethod @@ -103,6 +60,15 @@ def splits(self) -> set[str]: """Return the set of splits.""" +@dataclass +class SpeechReadOptions(DataReadOptions): + dtype: DataType = torch.float32 + """The data type of the decoded audio sequences.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" + + GENERIC_SPEECH_DATASET_FAMILY: Final = "generic_speech" @@ -119,28 +85,11 @@ def create_reader( self, split: str, gang: Gang, + min_audio_len: int, max_audio_len: int, batching: Batching, - *, - dtype: DataType = torch.float32, - min_audio_len: int = 1, - normalize_audio: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - cached_fd_count: int = 1000, + options: SpeechReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: - """ - :param cached_fd_count: - The maximum number of file descriptors to keep open while reading - audio files. 
- """ raise NotSupportedError("not supported yet.") @override diff --git a/src/fairseq2/datasets/text.py b/src/fairseq2/datasets/text.py index bcd1dbdc5..4a5fb8457 100644 --- a/src/fairseq2/datasets/text.py +++ b/src/fairseq2/datasets/text.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from dataclasses import dataclass from functools import partial from pathlib import Path from typing import Any, Final, cast, final @@ -23,8 +24,13 @@ read_sequence, ) from fairseq2.data.text import TextTokenEncoder, read_text -from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching -from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader from fairseq2.datasets.error import DatasetError from fairseq2.datasets.static import load_dataset from fairseq2.error import NotSupportedError @@ -43,19 +49,10 @@ def create_reader( text_encoder: TextTokenEncoder, pad_idx: int | None, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - min_seq_len: int = 1, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, + options: TextReadOptions | None = None, ) -> DataReader[SequenceBatch]: """Create a dataset reader. @@ -65,48 +62,24 @@ def create_reader( The index of the PAD symbol in the vocabulary of ``text_encoder``. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param min_seq_len: - The minimum sequence length of each example. Examples shorter than - this value will be dropped. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param sync_mode: - If ``until_first``, stops iteration when the first rank reaches end - of data. If ``until_last``, stops iteration when the last rank - reaches end of data; ranks that have already reached their end of - data will return an empty list of batches. - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. 
- :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. + :param options: + The read options. """ +@dataclass +class TextReadOptions(DataReadOptions): + pass + + # TODO: FIX, INFER npc = 10 @@ -156,20 +129,16 @@ def create_reader( text_encoder: TextTokenEncoder, pad_idx: int | None, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: Batching, - *, - min_seq_len: int = 1, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - sync_mode: SyncMode = "until_first", - max_num_batches: int | None = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, + options: TextReadOptions | None = None, ) -> DataReader[SequenceBatch]: + if options is None: + options = TextReadOptions() + + seed = options.seed + if len(self._files) == 1: builder = read_text(self._files[0], key="text", rtrim=True) else: @@ -181,8 +150,8 @@ def read_file(file: Path) -> DataPipeline: builder.yield_from(read_file) # Shuffle examples. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(example_shuffle_window, seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed) seed += 1 @@ -212,7 +181,7 @@ def encode(example: dict[str, Any]) -> dict[str, Any]: min_data_len=min_seq_len, skip_below_min_examples=True, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) elif isinstance(batching, StaticBatching): # Filter out out-of-range examples. @@ -224,13 +193,13 @@ def skip(example: dict[str, Any]) -> bool: builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) else: raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed) seed += 1 @@ -240,11 +209,11 @@ def skip(example: dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. 
- builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) f = partial(self._to_batch, device=gang.device) @@ -254,10 +223,10 @@ def skip(example: dict[str, Any]) -> bool: self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, - sync_mode=sync_mode, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) @staticmethod diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 8d0e97df1..fbcab14a8 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -19,7 +19,7 @@ from fairseq2.config_registry import ConfigRegistry from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import TextTokenizer, load_text_tokenizer -from fairseq2.datasets.batching import StaticBatching +from fairseq2.datasets import StaticBatching from fairseq2.logging import get_log_writer from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.models.sequence import SequenceBatch diff --git a/src/fairseq2/recipes/hg/dataset.py b/src/fairseq2/recipes/hg/dataset.py index 933334262..5ac62ee9b 100644 --- a/src/fairseq2/recipes/hg/dataset.py +++ b/src/fairseq2/recipes/hg/dataset.py @@ -10,7 +10,7 @@ from typing import Any from fairseq2.data.data_pipeline import Collater, create_bucket_sizes, read_sequence -from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching +from fairseq2.datasets import Batching, LengthBatching, StaticBatching from fairseq2.datasets.data_reader import BatchT, DataPipelineReader from fairseq2.gang import Gang diff --git a/src/fairseq2/recipes/lm/eval_nll.py b/src/fairseq2/recipes/lm/eval_nll.py index 6e26f2e80..5465aa742 100644 --- a/src/fairseq2/recipes/lm/eval_nll.py +++ b/src/fairseq2/recipes/lm/eval_nll.py @@ -15,9 +15,10 @@ from fairseq2.assets import AssetNotFoundError from fairseq2.config_registry import ConfigRegistry from fairseq2.data.text import load_text_tokenizer -from fairseq2.datasets.batching import LengthBatching +from fairseq2.datasets import LengthBatching from fairseq2.datasets.instruction import ( GenericInstructionDataset, + InstructionReadOptions, load_instruction_dataset, ) from fairseq2.logging import get_log_writer @@ -85,6 +86,9 @@ class NLLEvalConfig: max_num_tokens: int = 8192 * 2 """The maximum number of tokens per batch.""" + min_seq_len: int = 1 + """The minimum sequence length.""" + max_seq_len: int = 8192 """The maximum sequence length.""" @@ -208,12 +212,7 @@ def load_nll_evaluator( criterion = InstructionFinetuneCriterion(dp_model) unit = InstructionValidUnit(criterion, dp_gang) - data_reader = dataset.create_reader( - config.valid_split, - tokenizer, - dp_gang, - config.max_seq_len, - batching=LengthBatching(config.max_num_tokens), + options = InstructionReadOptions( example_shuffle_window=config.example_shuffle_window, batch_shuffle_window=config.batch_shuffle_window, sync_mode="until_last", @@ -222,6 +221,16 @@ def load_nll_evaluator( seed=seed, ) + data_reader = dataset.create_reader( + config.valid_split, + tokenizer, + dp_gang, + config.min_seq_len, + config.max_seq_len, + LengthBatching(config.max_num_tokens), + options=options, + ) + # TODO: Fix once we support static mixed precision on one device. 
     if config.mixed_precision == "static":
         amp = root_gang.size == 1 or config.data_parallelism != "fsdp"
diff --git a/src/fairseq2/recipes/lm/instruction_finetune.py b/src/fairseq2/recipes/lm/instruction_finetune.py
index 699ee5485..90a391a16 100644
--- a/src/fairseq2/recipes/lm/instruction_finetune.py
+++ b/src/fairseq2/recipes/lm/instruction_finetune.py
@@ -23,6 +23,7 @@
 from fairseq2.datasets import Batching, LengthBatching, StaticBatching
 from fairseq2.datasets.instruction import (
     GenericInstructionDataset,
+    InstructionReadOptions,
     load_instruction_dataset,
 )
 from fairseq2.gang import Gang
@@ -74,6 +75,9 @@ class InstructionFinetuneConfig:
     valid_split: str | None = None
     """The name of the valid data split."""

+    min_seq_len: int = 1
+    """The minimum sequence length."""
+
     max_seq_len: int = 8192
     """The maximum sequence length."""

@@ -444,12 +448,7 @@ def load_instruction_finetuner(
     else:
         batching = LengthBatching(config.max_num_tokens)

     try:
-        data_reader = dataset.create_reader(
-            config.train_split,
-            tokenizer,
-            dp_gang,
-            config.max_seq_len,
-            batching=batching,
+        options = InstructionReadOptions(
             example_shuffle_window=config.example_shuffle_window,
             batch_shuffle_window=config.batch_shuffle_window,
             num_accumulate=config.gradient_accumulation,
             num_prefetch=config.num_prefetch,
             source_encode_mode=config.src_encode_mode,
             target_encode_mode=config.tgt_encode_mode,
             seed=seed,
         )
+
+        data_reader = dataset.create_reader(
+            config.train_split,
+            tokenizer,
+            dp_gang,
+            config.min_seq_len,
+            config.max_seq_len,
+            batching,
+            options,
+        )
     except ValueError as ex:
         raise ValueError(
             "The data reader cannot be initialized. See nested exception for details."
@@ -494,21 +503,26 @@ def load_instruction_finetuner(
     max_num_tokens = config.max_num_valid_tokens or config.max_num_tokens

+    options = InstructionReadOptions(
+        example_shuffle_window=config.example_shuffle_window,
+        batch_shuffle_window=config.batch_shuffle_window,
+        sync_mode="until_last",
+        num_accumulate=config.gradient_accumulation,
+        num_prefetch=config.num_prefetch,
+        source_encode_mode=config.src_encode_mode,
+        target_encode_mode=config.tgt_encode_mode,
+        seed=seed,
+    )
+
     try:
         valid_data_reader = dataset.create_reader(
             config.valid_split,
             tokenizer,
             dp_gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=LengthBatching(max_num_tokens),
-            example_shuffle_window=config.example_shuffle_window,
-            batch_shuffle_window=config.batch_shuffle_window,
-            sync_mode="until_last",
-            num_accumulate=config.gradient_accumulation,
-            num_prefetch=config.num_prefetch,
-            source_encode_mode=config.src_encode_mode,
-            target_encode_mode=config.tgt_encode_mode,
-            seed=seed,
+            LengthBatching(max_num_tokens),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/lm/preference_finetune/recipe.py b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
index e5d8e6c2a..12ad37791 100644
--- a/src/fairseq2/recipes/lm/preference_finetune/recipe.py
+++ b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -21,6 +21,7 @@
 from fairseq2.datasets.preference import (
     GenericPreferenceOptimizationDataset,
     PreferenceOptimizationBatch,
+    PreferenceReadOptions,
     load_preference_optimization_dataset,
 )
 from fairseq2.logging import get_log_writer
@@ -55,6 +56,10 @@ class PreferenceFinetuneConfig:
     dataset: AssetReference = "gsm8k_dpo_data"  # TODO: change!
     """The name, path, or path to the asset card of the preference optimization dataset."""

+    min_seq_len: int = 1
+    """The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``.
+    Shorter sequences will be dropped."""
+
     max_seq_len: int = 8192
     """The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``.
     Longer sequences will be dropped."""
@@ -400,20 +405,25 @@ def load_preference_finetuner(
     else:
         batching = LengthBatching(config.max_num_tokens)

+    options = PreferenceReadOptions(
+        example_shuffle_window=config.example_shuffle_window,
+        batch_shuffle_window=config.batch_shuffle_window,
+        num_accumulate=config.gradient_accumulation,
+        num_prefetch=config.num_prefetch,
+        mask_source_tokens=config.mask_source_tokens,
+        source_encode_mode=config.src_encode_mode,
+        target_encode_mode=config.tgt_encode_mode,
+        seed=config.seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             tokenizer,
             dp_gang,
-            max_seq_len=config.max_seq_len,
-            batching=batching,
-            example_shuffle_window=config.example_shuffle_window,
-            batch_shuffle_window=config.batch_shuffle_window,
-            num_accumulate=config.gradient_accumulation,
-            num_prefetch=config.num_prefetch,
-            mask_source_tokens=config.mask_source_tokens,
-            source_encode_mode=config.src_encode_mode,
-            target_encode_mode=config.tgt_encode_mode,
-            seed=config.seed,
+            config.min_seq_len,
+            config.max_seq_len,
+            batching,
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/lm/text_generate.py b/src/fairseq2/recipes/lm/text_generate.py
index 07752cb8a..c020281bb 100644
--- a/src/fairseq2/recipes/lm/text_generate.py
+++ b/src/fairseq2/recipes/lm/text_generate.py
@@ -21,6 +21,7 @@
 from fairseq2.datasets import StaticBatching
 from fairseq2.datasets.instruction import (
     GenericInstructionDataset,
+    InstructionPromptReadOptions,
     load_instruction_dataset,
 )
 from fairseq2.gang import Gang
@@ -56,6 +57,9 @@ class TextGenerateConfig:
     split: str = "default"
     """The name of the data split."""

+    min_seq_len: int = 1
+    """The minimum sequence length."""
+
     max_seq_len: int = 8192
     """The maximum sequence length."""

@@ -279,15 +283,19 @@ def load_text_generator(
         json_output_stream=json_output_fp,
     )

+    options = InstructionPromptReadOptions(
+        sync_mode="until_last", num_prefetch=config.num_prefetch
+    )
+
     try:
         data_reader = dataset.create_prompt_reader(
             config.split,
             tokenizer,
             dp_gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=StaticBatching(config.batch_size),
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
+            StaticBatching(config.batch_size),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/mt/eval.py b/src/fairseq2/recipes/mt/eval.py
index ef6236b3b..f42f2ea50 100644
--- a/src/fairseq2/recipes/mt/eval.py
+++ b/src/fairseq2/recipes/mt/eval.py
@@ -21,6 +21,7 @@
 from fairseq2.datasets.parallel_text import (
     Direction,
     GenericParallelTextDataset,
+    ParallelTextReadOptions,
     load_parallel_text_dataset,
 )
 from fairseq2.gang import Gang
@@ -62,6 +63,9 @@ class MTEvalConfig:
     split: str = "test"
     """The name of the test data split."""

+    min_seq_len: int = 1
+    """The minimum sequence length."""
+
     max_seq_len: int = 512
     """The maximum sequence length."""

@@ -211,17 +215,22 @@ def load_mt_evaluator(
     units.append(loss_unit)

+    options = ParallelTextReadOptions(
+        direction=direction,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.split,
             tokenizer,
             gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=LengthBatching(config.max_num_tokens),
-            direction=direction,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            LengthBatching(config.max_num_tokens),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
@@ -285,17 +294,22 @@ def load_mt_evaluator(
     units.append(score_unit)

+    options = ParallelTextReadOptions(
+        direction=direction,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.split,
             tokenizer,
             gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=StaticBatching(config.generator_batch_size),
-            direction=direction,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            StaticBatching(config.generator_batch_size),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/mt/train.py b/src/fairseq2/recipes/mt/train.py
index 30381c6c0..5b71cc60c 100644
--- a/src/fairseq2/recipes/mt/train.py
+++ b/src/fairseq2/recipes/mt/train.py
@@ -21,6 +21,7 @@
 from fairseq2.datasets import LengthBatching, StaticBatching
 from fairseq2.datasets.parallel_text import (
     GenericParallelTextDataset,
+    ParallelTextReadOptions,
     load_parallel_text_dataset,
 )
 from fairseq2.gang import Gang
@@ -68,6 +69,9 @@ class MTTrainConfig:
     valid_split: str = "valid"
     """The name of the valid data split."""

+    min_seq_len: int = 1
+    """The minimum sequence length."""
+
     max_seq_len: int = 512
     """The maximum sequence length."""

@@ -305,19 +309,24 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
     # Initialize the train unit.
     unit = MTTrainUnit(criterion, gang)

+    options = ParallelTextReadOptions(
+        sample=True,
+        example_shuffle_window=config.example_shuffle_window,
+        batch_shuffle_window=config.batch_shuffle_window,
+        num_accumulate=config.gradient_accumulation,
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.split,
             tokenizer,
             gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=LengthBatching(config.max_num_tokens),
-            sample=True,
-            example_shuffle_window=config.example_shuffle_window,
-            batch_shuffle_window=config.batch_shuffle_window,
-            num_accumulate=config.gradient_accumulation,
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            LengthBatching(config.max_num_tokens),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
@@ -373,17 +382,22 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
     valid_units.append(valid_loss_unit)

+    options = ParallelTextReadOptions(
+        direction=direction,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         valid_data_reader = dataset.create_reader(
             config.valid_split,
             tokenizer,
             gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=LengthBatching(config.max_num_tokens),
-            direction=direction,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            LengthBatching(config.max_num_tokens),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
@@ -402,17 +416,22 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB
     valid_units.append(valid_score_unit)

+    options = ParallelTextReadOptions(
+        direction=direction,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         valid_data_reader = dataset.create_reader(
             config.valid_split,
             tokenizer,
             gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=StaticBatching(config.generator_batch_size),
-            direction=direction,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            StaticBatching(config.generator_batch_size),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/mt/translate.py b/src/fairseq2/recipes/mt/translate.py
index ed788cf3d..4965d2020 100644
--- a/src/fairseq2/recipes/mt/translate.py
+++ b/src/fairseq2/recipes/mt/translate.py
@@ -18,7 +18,11 @@
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.data.text import TextTokenizer, load_text_tokenizer
 from fairseq2.datasets import StaticBatching
-from fairseq2.datasets.text import GenericTextDataset, load_text_dataset
+from fairseq2.datasets.text import (
+    GenericTextDataset,
+    TextReadOptions,
+    load_text_dataset,
+)
 from fairseq2.gang import Gang
 from fairseq2.generation import (
     BeamSearchConfig,
@@ -59,6 +63,9 @@ class TextTranslateConfig:
     target_lang: str = "deu_Latn"
     """The code of the language to translate to."""

+    min_seq_len: int = 1
+    """The minimum sequence length."""
+
     max_seq_len: int = 512
     """The maximum sequence length."""

@@ -225,16 +232,19 @@ def load_text_translator(
     seed = config.seed

+    options = TextReadOptions(
+        sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed
+    )
+
     try:
         data_reader = dataset.create_reader(
             text_encoder,
             tokenizer.vocab_info.pad_idx,
             gang,
+            config.min_seq_len,
             config.max_seq_len,
-            batching=StaticBatching(config.batch_size),
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            StaticBatching(config.batch_size),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/wav2vec2/asr/eval.py b/src/fairseq2/recipes/wav2vec2/asr/eval.py
index 65c0ac43e..ddb2db985 100644
--- a/src/fairseq2/recipes/wav2vec2/asr/eval.py
+++ b/src/fairseq2/recipes/wav2vec2/asr/eval.py
@@ -18,7 +18,7 @@
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.data.text import load_text_tokenizer
 from fairseq2.datasets import LengthBatching
-from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset
+from fairseq2.datasets.asr import AsrReadOptions, GenericAsrDataset, load_asr_dataset
 from fairseq2.gang import Gang
 from fairseq2.logging import get_log_writer
 from fairseq2.models import load_model
@@ -206,19 +206,23 @@ def load_wav2vec2_asr_evaluator(
     seed = config.seed

+    options = AsrReadOptions(
+        dtype=config.dtype,
+        normalize_audio=config.normalize_audio,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.split,
             tokenizer,
             gang,
-            batching=LengthBatching(config.max_num_elements),
-            dtype=config.dtype,
-            min_audio_len=config.min_audio_len,
-            max_audio_len=config.max_audio_len,
-            normalize_audio=config.normalize_audio,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            config.min_audio_len,
+            config.max_audio_len,
+            LengthBatching(config.max_num_elements),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/wav2vec2/asr/train.py b/src/fairseq2/recipes/wav2vec2/asr/train.py
index 355f07cc5..9dae58d9d 100644
--- a/src/fairseq2/recipes/wav2vec2/asr/train.py
+++ b/src/fairseq2/recipes/wav2vec2/asr/train.py
@@ -19,7 +19,7 @@
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.data.text import load_text_tokenizer
 from fairseq2.datasets import LengthBatching
-from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset
+from fairseq2.datasets.asr import AsrReadOptions, GenericAsrDataset, load_asr_dataset
 from fairseq2.gang import Gang
 from fairseq2.logging import get_log_writer
 from fairseq2.models import create_model
@@ -385,21 +385,25 @@ def load_wav2vec2_asr_trainer(
         criterion, gang, freeze_encoder_for_n_steps=config.freeze_encoder_for_n_steps
     )

+    options = AsrReadOptions(
+        dtype=config.dtype,
+        normalize_audio=config.normalize_audio,
+        example_shuffle_window=config.example_shuffle_window,
+        batch_shuffle_window=config.batch_shuffle_window,
+        num_accumulate=config.gradient_accumulation,
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.train_split,
             tokenizer,
             gang,
-            batching=LengthBatching(config.max_num_elements),
-            dtype=config.dtype,
-            min_audio_len=config.min_audio_len,
-            max_audio_len=config.max_audio_len,
-            normalize_audio=config.normalize_audio,
-            example_shuffle_window=config.example_shuffle_window,
-            batch_shuffle_window=config.batch_shuffle_window,
-            num_accumulate=config.gradient_accumulation,
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            config.min_audio_len,
+            config.max_audio_len,
+            LengthBatching(config.max_num_elements),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
@@ -439,19 +443,23 @@ def load_wav2vec2_asr_trainer(
     # Initialize the validation unit.
     valid_unit = Wav2Vec2AsrEvalUnit(valid_criterion, gang)

+    options = AsrReadOptions(
+        dtype=config.dtype,
+        normalize_audio=config.normalize_audio,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         valid_data_reader = dataset.create_reader(
             config.valid_split,
             tokenizer,
             gang,
-            batching=LengthBatching(config.max_num_elements),
-            dtype=config.dtype,
-            min_audio_len=config.min_audio_len,
-            max_audio_len=config.max_audio_len,
-            normalize_audio=config.normalize_audio,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            config.min_audio_len,
+            config.max_audio_len,
+            LengthBatching(config.max_num_elements),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/wav2vec2/eval.py b/src/fairseq2/recipes/wav2vec2/eval.py
index 6f180fb00..0481c3fe1 100644
--- a/src/fairseq2/recipes/wav2vec2/eval.py
+++ b/src/fairseq2/recipes/wav2vec2/eval.py
@@ -16,8 +16,12 @@
 from fairseq2.assets import AssetNotFoundError, default_asset_store
 from fairseq2.checkpoint import CheckpointModelMetadataProvider
 from fairseq2.config_registry import ConfigRegistry
-from fairseq2.datasets.batching import LengthBatching
-from fairseq2.datasets.speech import GenericSpeechDataset, load_speech_dataset
+from fairseq2.datasets import LengthBatching
+from fairseq2.datasets.speech import (
+    GenericSpeechDataset,
+    SpeechReadOptions,
+    load_speech_dataset,
+)
 from fairseq2.gang import Gang
 from fairseq2.logging import get_log_writer
 from fairseq2.models import load_model
@@ -175,18 +179,22 @@ def load_wav2vec2_evaluator(
     seed = config.seed

+    options = SpeechReadOptions(
+        dtype=config.dtype,
+        normalize_audio=config.normalize_audio,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.split,
             gang,
+            config.min_audio_len,
             config.max_audio_len,
-            batching=LengthBatching(config.max_num_elements),
-            dtype=config.dtype,
-            min_audio_len=config.min_audio_len,
-            normalize_audio=config.normalize_audio,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            LengthBatching(config.max_num_elements),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
diff --git a/src/fairseq2/recipes/wav2vec2/train.py b/src/fairseq2/recipes/wav2vec2/train.py
index 891ec202c..48955f38d 100644
--- a/src/fairseq2/recipes/wav2vec2/train.py
+++ b/src/fairseq2/recipes/wav2vec2/train.py
@@ -17,8 +17,12 @@
 from fairseq2.assets import AssetNotFoundError, default_asset_store
 from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager
 from fairseq2.config_registry import ConfigRegistry
-from fairseq2.datasets.batching import LengthBatching
-from fairseq2.datasets.speech import GenericSpeechDataset, load_speech_dataset
+from fairseq2.datasets import LengthBatching
+from fairseq2.datasets.speech import (
+    GenericSpeechDataset,
+    SpeechReadOptions,
+    load_speech_dataset,
+)
 from fairseq2.gang import Gang
 from fairseq2.logging import get_log_writer
 from fairseq2.models import create_model
@@ -297,19 +301,23 @@ def load_wav2vec2_trainer(
     # Initialize the train unit.
     unit = Wav2Vec2TrainUnit(criterion, gang)

+    options = SpeechReadOptions(
+        dtype=config.dtype,
+        normalize_audio=config.normalize_audio,
+        batch_shuffle_window=config.batch_shuffle_window,
+        num_accumulate=config.gradient_accumulation,
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         data_reader = dataset.create_reader(
             config.train_split,
             gang,
-            batching=LengthBatching(config.max_num_elements),
-            dtype=config.dtype,
-            min_audio_len=config.min_audio_len,
-            max_audio_len=config.max_audio_len,
-            normalize_audio=config.normalize_audio,
-            batch_shuffle_window=config.batch_shuffle_window,
-            num_accumulate=config.gradient_accumulation,
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            config.min_audio_len,
+            config.max_audio_len,
+            LengthBatching(config.max_num_elements),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
@@ -344,18 +352,22 @@ def load_wav2vec2_trainer(
     # Initialize the validation unit.
     valid_unit = Wav2Vec2EvalUnit(criterion, gang)

+    options = SpeechReadOptions(
+        dtype=config.dtype,
+        normalize_audio=config.normalize_audio,
+        sync_mode="until_last",
+        num_prefetch=config.num_prefetch,
+        seed=seed,
+    )
+
     try:
         valid_data_reader = dataset.create_reader(
             config.valid_split,
             gang,
-            batching=LengthBatching(config.max_num_elements),
-            dtype=config.dtype,
-            min_audio_len=config.min_audio_len,
-            max_audio_len=config.max_audio_len,
-            normalize_audio=config.normalize_audio,
-            sync_mode="until_last",
-            num_prefetch=config.num_prefetch,
-            seed=seed,
+            config.min_audio_len,
+            config.max_audio_len,
+            LengthBatching(config.max_num_elements),
+            options,
         )
     except ValueError as ex:
         raise ValueError(
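
For reference, a minimal usage sketch (not part of the patch) of the refactored reader API: the create_reader() signature and the read-options dataclasses mirror the diff above, while the asset names and the gang setup are illustrative assumptions only.

# Illustrative sketch, assuming hypothetical asset names; signatures follow the diff above.
from fairseq2.data.text import load_text_tokenizer
from fairseq2.datasets import LengthBatching
from fairseq2.datasets.instruction import (
    InstructionReadOptions,
    load_instruction_dataset,
)
from fairseq2.gang import setup_default_gang

gang = setup_default_gang()

tokenizer = load_text_tokenizer("my_tokenizer")     # hypothetical asset name
dataset = load_instruction_dataset("my_sft_data")   # hypothetical asset name

# Tuning knobs that used to be keyword arguments of create_reader() now live
# on a read-options dataclass.
options = InstructionReadOptions(
    example_shuffle_window=10_000,
    batch_shuffle_window=1_000,
    num_prefetch=4,
    seed=2,
)

data_reader = dataset.create_reader(
    "default",                 # split
    tokenizer,
    gang,
    1,                         # min_seq_len
    8192,                      # max_seq_len
    LengthBatching(8192 * 2),  # maximum number of tokens per batch
    options,
)

for batches in data_reader:
    ...  # each iteration yields a list of batches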