Refactor dataset configs #956

Merged · 1 commit · Jan 6, 2025
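This PR replaces fairseq2.datasets.batching with a new fairseq2.datasets.config module that hosts StaticBatching, LengthBatching, the Batching alias, and a new DataReadOptions dataclass. The long keyword lists of AsrDataset.create_reader collapse into a single options argument; AsrReadOptions adds the ASR-specific dtype and normalize_audio fields, and reader-specific settings such as cached_fd_count move into options.extras.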
7 changes: 4 additions & 3 deletions src/fairseq2/datasets/__init__.py
@@ -6,9 +6,10 @@

 from __future__ import annotations

-from fairseq2.datasets.batching import Batching as Batching
-from fairseq2.datasets.batching import LengthBatching as LengthBatching
-from fairseq2.datasets.batching import StaticBatching as StaticBatching
+from fairseq2.datasets.config import Batching as Batching
+from fairseq2.datasets.config import DataReadOptions as DataReadOptions
+from fairseq2.datasets.config import LengthBatching as LengthBatching
+from fairseq2.datasets.config import StaticBatching as StaticBatching
 from fairseq2.datasets.data_reader import DataPipelineReader as DataPipelineReader
 from fairseq2.datasets.data_reader import DataReader as DataReader
 from fairseq2.datasets.data_reader import SyncMode as SyncMode
144 changes: 58 additions & 86 deletions src/fairseq2/datasets/asr.py
@@ -7,6 +7,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Final, cast, final

@@ -28,8 +29,13 @@
 )
 from fairseq2.data.audio import AudioDecoder
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
-from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching
-from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode
+from fairseq2.datasets.config import (
+    Batching,
+    DataReadOptions,
+    LengthBatching,
+    StaticBatching,
+)
+from fairseq2.datasets.data_reader import DataPipelineReader, DataReader
 from fairseq2.datasets.error import DatasetError, SplitNotFoundError
 from fairseq2.datasets.static import load_dataset
 from fairseq2.error import NotSupportedError
@@ -48,21 +54,10 @@ def create_reader(
         split: str,
         tokenizer: TextTokenizer,
         gang: Gang,
+        min_audio_len: int,
         max_audio_len: int,
         batching: Batching,
         *,
-        dtype: DataType = torch.float32,
-        min_audio_len: int = 1,
-        normalize_audio: bool = False,
-        example_shuffle_window: int = 1,
-        batch_shuffle_window: int = 1,
-        drop_remainder: bool = False,
-        sync_batches: bool = True,
-        sync_mode: SyncMode = "until_first",
-        max_num_batches: int | None = None,
-        num_accumulate: int = 1,
-        num_prefetch: int = 1,
-        seed: int = 2,
+        options: AsrReadOptions | None = None,
     ) -> DataReader[Seq2SeqBatch]:
         """Create a dataset reader.

@@ -72,56 +67,32 @@
             The tokenizer to encode target text.
         :param gang:
             The gang over which to shard the dataset.
+        :param min_audio_len:
+            The minimum audio length of each example. Examples shorter than this
+            value will be dropped.
         :param max_audio_len:
-            The maximum audio length of each example. Examples longer than
-            this value will be dropped.
+            The maximum audio length of each example. Examples longer than this
+            value will be dropped.
         :param batching:
             The batching strategy for returned examples.
-        :param dtype:
-            The data type of the decoded audio sequences.
-        :param min_audio_len:
-            The minimum audio length of each example. Examples shorter than
-            this value will be dropped.
-        :param normalize_audio:
-            If ``True``, normalizes audio to have zero mean and unit variance.
-        :param example_shuffle_window:
-            The size of the sliding window for shuffling examples. If ``1``, no
-            shuffling is performed; if ``0``, true shuffling is performed by
-            loading the entire dataset.
-        :param batch_shuffle_window:
-            The size of the sliding window for shuffling batches. If ``1``, no
-            shuffling is performed; if ``0``, true shuffling is performed by
-            loading the entire dataset.
-        :param drop_remainder:
-            If ``True``, drops the last set of batches if they have in total
-            fewer examples than requested.
-        :param sync_batches:
-            If ``True``, ensures that each process in ``gang`` reads the same
-            number of batches. Typically used when the amount of data to be read
-            can vary per process (e.g. due to unbalanced sharding or non-static
-            batching) and it is critical for each process to iterate over the
-            same number of batches (e.g. during training).
-        :param sync_mode:
-            If ``until_first``, stops iteration when the first rank reaches end
-            of data. If ``until_last``, stops iteration when the last rank
-            reaches end of data; ranks that have already reached their end of
-            data will return an empty list of batches.
-        :param max_num_batches:
-            The maximum number of batches to return.
-        :param num_accumulate:
-            The number of batches to accumulate in each iteration. Typically
-            used with gradient accumulation during training.
-        :param num_prefetch:
-            The number of batches to prefetch in background.
-        :param seed:
-            The seed to initialize the random number generators used internally.
+        :param options:
+            The read options.
         """

     @abstractmethod
     def splits(self) -> set[str]:
         """Return the set of splits."""


+@dataclass
+class AsrReadOptions(DataReadOptions):
+    dtype: DataType = torch.float32
+    """The data type of the decoded audio sequences."""
+
+    normalize_audio: bool = False
+    """If ``True``, normalizes audio to have zero mean and unit variance."""
+
+
 # TODO: FIX, INFER
 npc = 10

@@ -174,22 +145,10 @@ def create_reader(
         split: str,
         tokenizer: TextTokenizer,
         gang: Gang,
+        min_audio_len: int,
         max_audio_len: int,
         batching: Batching,
         *,
-        dtype: DataType = torch.float32,
-        min_audio_len: int = 1,
-        normalize_audio: bool = False,
-        example_shuffle_window: int = 1,
-        batch_shuffle_window: int = 1,
-        drop_remainder: bool = False,
-        sync_batches: bool = True,
-        sync_mode: SyncMode = "until_first",
-        max_num_batches: int | None = None,
-        num_accumulate: int = 1,
-        num_prefetch: int = 1,
-        seed: int = 2,
-        cached_fd_count: int = 1000,
+        options: AsrReadOptions | None = None,
     ) -> DataPipelineReader[Seq2SeqBatch]:
         """
-        :param cached_fd_count:
@@ -199,13 +158,18 @@
         if split not in self._splits:
             raise SplitNotFoundError(self._name, split, self._splits)

+        if options is None:
+            options = AsrReadOptions()
+
+        seed = options.seed
+
         audio_dir = self._retrieve_data_directory(split)

         builder = self._read_manifest(split)

         # Shuffle examples. Must be consistent across all processes.
-        if example_shuffle_window != 1:
-            builder.shuffle(example_shuffle_window, seed)
+        if options.example_shuffle_window != 1:
+            builder.shuffle(options.example_shuffle_window, seed)

         seed += 1

@@ -217,8 +181,8 @@
         if isinstance(batching, LengthBatching):
             # Bucket by the audio length.
             bucket_sizes = create_bucket_sizes(
-                max_seq_len=max_audio_len,
                 min_seq_len=min_audio_len,
+                max_seq_len=max_audio_len,
                 max_num_elements=batching.max_num_elements,
                 num_seqs_multiple_of=8,
             )
@@ -229,7 +193,7 @@
                 min_data_len=min_audio_len,
                 skip_below_min_examples=True,
                 skip_above_max_examples=True,
-                drop_remainder=drop_remainder,
+                drop_remainder=options.drop_remainder,
             )
         elif isinstance(batching, StaticBatching):
             # Filter out out-of-range audios.
@@ -241,23 +205,31 @@ def skip(example: dict[str, object]) -> bool:
             builder.filter(skip)

             # Bucket `batch_size` examples.
-            builder.bucket(batching.batch_size, drop_remainder=drop_remainder)
+            builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder)
         else:
             raise NotSupportedError(f"`{batching}` is not supported.")

         # Shuffle buckets.
-        if batch_shuffle_window != 1:
-            builder.shuffle(batch_shuffle_window, seed)
+        if options.batch_shuffle_window != 1:
+            builder.shuffle(options.batch_shuffle_window, seed)

         seed += 1

         # Memory map audio files.
+        cached_fd_count = options.extras.get("cached_fd_count", 1)
+        if not isinstance(cached_fd_count, int):
+            raise TypeError(
+                f"`options.extras['cached_fd_count']` must be of type `int`, but is of type `{type(cached_fd_count)}` instead."
+            )
+
         file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count)

         builder.map(file_mapper, selector="[*].audio")

         # Decode audio.
-        audio_decoder = AudioDecoder(dtype=torch.float32 if normalize_audio else dtype)
+        audio_decoder = AudioDecoder(
+            dtype=torch.float32 if options.normalize_audio else options.dtype
+        )

         builder.map(audio_decoder, selector="[*].audio.data")

@@ -268,9 +240,9 @@ def normalize(waveform: Tensor) -> Tensor:
             with torch.no_grad():
                 waveform = layer_norm(waveform, waveform.shape)

-            return waveform.to(dtype)
+            return waveform.to(options.dtype)

-        if normalize_audio:
+        if options.normalize_audio:
             builder.map(normalize, selector="[*].audio.data.waveform")

         # Tokenize target text.
@@ -288,11 +260,11 @@ def normalize(waveform: Tensor) -> Tensor:
         builder.map(collater, num_parallel_calls=npc)

         # Return only the first `max_num_batches`.
-        if max_num_batches is not None:
-            builder.take(max_num_batches)
+        if options.max_num_batches is not None:
+            builder.take(options.max_num_batches)

         # Prefetch `num_prefetch` batches in background.
-        builder.prefetch(num_prefetch)
+        builder.prefetch(options.num_prefetch)

         # Wrap examples with `Seq2SeqBatch`.
         def to_batch(example: dict[str, Any]) -> Seq2SeqBatch:
@@ -320,10 +292,10 @@ def to_batch(example: dict[str, Any]) -> Seq2SeqBatch:
             self._name,
             pipeline,
             gang,
-            num_accumulate=num_accumulate,
-            drop_remainder=drop_remainder,
-            sync_batches=sync_batches,
-            sync_mode=sync_mode,
+            num_accumulate=options.num_accumulate,
+            drop_remainder=options.drop_remainder,
+            sync_batches=options.sync_batches,
+            sync_mode=options.sync_mode,
         )

     def _retrieve_data_directory(self, split: str) -> Path:
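For reviewers, a minimal sketch of the call pattern after this change, assuming a loaded ASR dataset instance plus the usual tokenizer and gang objects; the variable names and numeric values below are illustrative, not defaults:

    import torch

    from fairseq2.datasets import LengthBatching
    from fairseq2.datasets.asr import AsrReadOptions

    # Generic read settings now live on one dataclass instead of a dozen
    # keyword arguments.
    options = AsrReadOptions(
        dtype=torch.float16,
        normalize_audio=True,
        batch_shuffle_window=1000,
        num_prefetch=4,
    )

    # Reader-specific knobs such as `cached_fd_count` travel through
    # `extras` rather than a dedicated parameter (see the handling above).
    options.extras["cached_fd_count"] = 1000

    reader = dataset.create_reader(  # `dataset`, `tokenizer`, `gang` assumed given
        "train",
        tokenizer,
        gang,
        min_audio_len=1,
        max_audio_len=480_000,
        batching=LengthBatching(max_num_elements=1_600_000),
        options=options,
    )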
29 changes: 0 additions & 29 deletions src/fairseq2/datasets/batching.py

This file was deleted.

90 changes: 90 additions & 0 deletions src/fairseq2/datasets/config.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import MutableMapping
from dataclasses import dataclass, field
from typing import TypeAlias

from fairseq2.datasets.data_reader import SyncMode


@dataclass
class StaticBatching:
"""Specifies batching where each batch has the same number of examples."""

batch_size: int
"""The number of examples in each batch."""


@dataclass
class LengthBatching:
"""Specifies batching where each batch has a maximum number of elements."""

max_num_elements: int
"""The maximum number of elements (e.g. tokens) in each batch."""


Batching: TypeAlias = StaticBatching | LengthBatching


@dataclass(kw_only=True)
class DataReadOptions:
example_shuffle_window: int = 1
"""
The size of the sliding window for shuffling examples. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by loading the
entire dataset.
"""

batch_shuffle_window: int = 1
"""
The size of the sliding window for shuffling batches. If ``1``, no
shuffling is performed; if ``0``, true shuffling is performed by loading the
entire dataset.
"""

drop_remainder: bool = False
"""
If ``True``, drops the last set of batches if they have in total fewer
examples than requested.
"""

sync_batches: bool = True
"""
If ``True``, ensures that each process in ``gang`` reads the same number of
batches. Typically used when the amount of data to be read can vary per
process (e.g. due to unbalanced sharding or non-static batching) and it is
critical for each process to iterate over the same number of batches (e.g.
during training).
"""

sync_mode: SyncMode = "until_first"
"""
If ``until_first``, stops iteration on all ranks when one of the ranks
reaches its end of data. If ``until_last``, stops iteration when all ranks
reach their end of data; ranks that have already reached their end of data
will return an empty list of batches.
"""

max_num_batches: int | None = None
"""The maximum number of batches to return."""

num_accumulate: int = 1
"""
The number of batches to accumulate in each iteration. Typically used with
gradient accumulation during training.
"""

num_prefetch: int = 1
"""The number of batches to prefetch in background."""

seed: int = 2
"""The seed to initialize the random number generators used internally."""

extras: MutableMapping[str, object] = field(default_factory=dict)
"""The reader-specific extra options."""