Commit 90baffc: Revise datasets

cbalioglu committed Jan 5, 2025
1 parent 0d9ad31 commit 90baffc
Showing 24 changed files with 590 additions and 611 deletions.
1 change: 0 additions & 1 deletion src/fairseq2/__init__.py
@@ -12,7 +12,6 @@

# isort: split

import fairseq2.datasets
import fairseq2.models

# isort: split
3 changes: 2 additions & 1 deletion src/fairseq2/chatbots/llama.py
@@ -14,7 +14,8 @@

from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog, ChatMessage
from fairseq2.chatbots.handler import ChatbotHandler
from fairseq2.data.text import LLaMA3Tokenizer, TextTokenEncoder, TextTokenizer
from fairseq2.data.text import TextTokenEncoder, TextTokenizer
from fairseq2.data.text.tokenizers.llama import LLaMA3Tokenizer
from fairseq2.generation import SequenceGenerator
from fairseq2.nn.utils.module import infer_device

84 changes: 5 additions & 79 deletions src/fairseq2/data/text/__init__.py
@@ -11,82 +11,8 @@
from fairseq2.data.text.converters import StrToTensorConverter as StrToTensorConverter
from fairseq2.data.text.text_reader import LineEnding as LineEnding
from fairseq2.data.text.text_reader import read_text as read_text
from fairseq2.data.text.tokenizers.char_tokenizer import (
CHAR_TOKENIZER_FAMILY as CHAR_TOKENIZER_FAMILY,
)
from fairseq2.data.text.tokenizers.handler import (
StandardTextTokenizerHandler as StandardTextTokenizerHandler,
)
from fairseq2.data.text.tokenizers.handler import (
TextTokenizerHandler as TextTokenizerHandler,
)
from fairseq2.data.text.tokenizers.handler import (
TextTokenizerLoader as TextTokenizerLoader,
)
from fairseq2.data.text.tokenizers.handler import (
TextTokenizerNotFoundError as TextTokenizerNotFoundError,
)
from fairseq2.data.text.tokenizers.handler import (
get_text_tokenizer_family as get_text_tokenizer_family,
)
from fairseq2.data.text.tokenizers.llama import (
LLAMA_TOKENIZER_FAMILY as LLAMA_TOKENIZER_FAMILY,
)
from fairseq2.data.text.tokenizers.llama import LLaMA3Tokenizer as LLaMA3Tokenizer
from fairseq2.data.text.tokenizers.mistral import (
MISTRAL_TOKENIZER_FAMILY as MISTRAL_TOKENIZER_FAMILY,
)
from fairseq2.data.text.tokenizers.nllb import (
NLLB_TOKENIZER_FAMILY as NLLB_TOKENIZER_FAMILY,
)
from fairseq2.data.text.tokenizers.nllb import NllbTokenizer as NllbTokenizer
from fairseq2.data.text.tokenizers.ref import (
resolve_text_tokenizer_reference as resolve_text_tokenizer_reference,
)
from fairseq2.data.text.tokenizers.s2t_transformer import (
S2T_TRANSFORMER_TOKENIZER_FAMILY as S2T_TRANSFORMER_TOKENIZER_FAMILY,
)
from fairseq2.data.text.tokenizers.s2t_transformer import (
S2TTransformerTokenizer as S2TTransformerTokenizer,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
BasicSentencePieceTokenizer as BasicSentencePieceTokenizer,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
RawSentencePieceTokenizer as RawSentencePieceTokenizer,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
SentencePieceDecoder as SentencePieceDecoder,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
SentencePieceEncoder as SentencePieceEncoder,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
SentencePieceModel as SentencePieceModel,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
SentencePieceTokenizer as SentencePieceTokenizer,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
load_basic_sentencepiece as load_basic_sentencepiece,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
load_raw_sentencepiece as load_raw_sentencepiece,
)
from fairseq2.data.text.tokenizers.sentencepiece import (
vocab_info_from_sentencepiece as vocab_info_from_sentencepiece,
)
from fairseq2.data.text.tokenizers.static import (
load_text_tokenizer as load_text_tokenizer,
)
from fairseq2.data.text.tokenizers.tiktoken import TiktokenDecoder as TiktokenDecoder
from fairseq2.data.text.tokenizers.tiktoken import TiktokenEncoder as TiktokenEncoder
from fairseq2.data.text.tokenizers.tiktoken import (
TiktokenTokenizer as TiktokenTokenizer,
)
from fairseq2.data.text.tokenizers.tokenizer import (
AbstractTextTokenizer as AbstractTextTokenizer,
)
from fairseq2.data.text.tokenizers.tokenizer import TextTokenDecoder as TextTokenDecoder
from fairseq2.data.text.tokenizers.tokenizer import TextTokenEncoder as TextTokenEncoder
from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer as TextTokenizer
from fairseq2.data.text.tokenizers import AbstractTextTokenizer as AbstractTextTokenizer
from fairseq2.data.text.tokenizers import TextTokenDecoder as TextTokenDecoder
from fairseq2.data.text.tokenizers import TextTokenEncoder as TextTokenEncoder
from fairseq2.data.text.tokenizers import TextTokenizer as TextTokenizer
from fairseq2.data.text.tokenizers import load_text_tokenizer as load_text_tokenizer
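Reviewer note: the deletions above move the tokenizer surface out of `fairseq2.data.text` and into the new `fairseq2.data.text.tokenizers` subpackage. A minimal sketch of the resulting import migration for downstream code, based only on the re-exports visible in this hunk; the full re-export surface may differ:

```python
# Before this commit, tokenizer symbols were re-exported from fairseq2.data.text:
# from fairseq2.data.text import LLaMA3Tokenizer, load_text_tokenizer

# After this commit, the core interface comes from the tokenizers subpackage,
# and family-specific tokenizers are imported from their own modules.
from fairseq2.data.text.tokenizers import TextTokenizer, load_text_tokenizer
from fairseq2.data.text.tokenizers.llama import LLaMA3Tokenizer
```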
35 changes: 35 additions & 0 deletions src/fairseq2/data/text/tokenizers/__init__.py
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.data.text.tokenizers.handler import (
StandardTextTokenizerHandler as StandardTextTokenizerHandler,
)
from fairseq2.data.text.tokenizers.handler import (
TextTokenizerHandler as TextTokenizerHandler,
)
from fairseq2.data.text.tokenizers.handler import (
TextTokenizerLoader as TextTokenizerLoader,
)
from fairseq2.data.text.tokenizers.handler import (
TextTokenizerNotFoundError as TextTokenizerNotFoundError,
)
from fairseq2.data.text.tokenizers.handler import (
get_text_tokenizer_family as get_text_tokenizer_family,
)
from fairseq2.data.text.tokenizers.ref import (
resolve_text_tokenizer_reference as resolve_text_tokenizer_reference,
)
from fairseq2.data.text.tokenizers.static import (
load_text_tokenizer as load_text_tokenizer,
)
from fairseq2.data.text.tokenizers.tokenizer import (
AbstractTextTokenizer as AbstractTextTokenizer,
)
from fairseq2.data.text.tokenizers.tokenizer import TextTokenDecoder as TextTokenDecoder
from fairseq2.data.text.tokenizers.tokenizer import TextTokenEncoder as TextTokenEncoder
from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer as TextTokenizer
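Reviewer note: the new subpackage keeps `load_text_tokenizer` as the single loading entry point. A hedged usage sketch, assuming the `TextTokenizer` interface retains its `create_encoder`/`create_decoder` factories; the card name below is hypothetical:

```python
from fairseq2.data.text.tokenizers import load_text_tokenizer

# Resolve a registered tokenizer asset card by name. The card name is
# hypothetical; substitute one registered in your asset store.
tokenizer = load_text_tokenizer("nllb-200")

# Encoder/decoder factories on TextTokenizer (assumed unchanged by this commit).
encoder = tokenizer.create_encoder()
decoder = tokenizer.create_decoder()

indices = encoder("Hello, world!")
print(decoder(indices))
```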
21 changes: 8 additions & 13 deletions src/fairseq2/datasets/__init__.py
@@ -11,17 +11,12 @@
from fairseq2.datasets.batching import StaticBatching as StaticBatching
from fairseq2.datasets.data_reader import DataPipelineReader as DataPipelineReader
from fairseq2.datasets.data_reader import DataReader as DataReader
from fairseq2.datasets.data_reader import SyncMode as SyncMode
from fairseq2.datasets.error import DataReadError as DataReadError
from fairseq2.datasets.error import DatasetError as DatasetError
from fairseq2.datasets.loader import AbstractDatasetLoader as AbstractDatasetLoader
from fairseq2.datasets.loader import DatasetLoader as DatasetLoader
from fairseq2.datasets.loader import DelegatingDatasetLoader as DelegatingDatasetLoader
from fairseq2.datasets.loader import get_dataset_family as get_dataset_family
from fairseq2.datasets.loader import is_dataset_card as is_dataset_card

# isort: split

import fairseq2.datasets.asr
import fairseq2.datasets.instruction
import fairseq2.datasets.parallel_text
import fairseq2.datasets.speech
import fairseq2.datasets.text
from fairseq2.datasets.handler import DatasetHandler as DatasetHandler
from fairseq2.datasets.handler import DatasetLoader as DatasetLoader
from fairseq2.datasets.handler import DatasetNotFoundError as DatasetNotFoundError
from fairseq2.datasets.handler import StandardDatasetHandler as StandardDatasetHandler
from fairseq2.datasets.handler import get_dataset_family as get_dataset_family
from fairseq2.datasets.static import load_dataset as load_dataset
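Reviewer note: the loader machinery (`AbstractDatasetLoader`, `DelegatingDatasetLoader`) is replaced by a handler-based design with a single generic `load_dataset` entry point. A sketch of the new call shape, inferred from the `load_asr_dataset` wrapper later in this commit; the card name is hypothetical:

```python
from fairseq2.datasets import load_dataset
from fairseq2.datasets.asr import AsrDataset

# load_dataset resolves the asset card, dispatches to the handler registered
# for the card's dataset family, and returns an instance checked against the
# expected type. "my_asr_card" is a hypothetical card name.
dataset = load_dataset("my_asr_card", AsrDataset)
```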
74 changes: 36 additions & 38 deletions src/fairseq2/datasets/asr.py
@@ -8,14 +8,14 @@

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Literal, cast, final
from typing import Any, Final, cast, final

import torch
from torch import Tensor
from torch.nn.functional import layer_norm
from typing_extensions import override

from fairseq2.assets import AssetCard, AssetError
from fairseq2.assets import AssetCard
from fairseq2.data import (
CollateOptionsOverride,
Collater,
@@ -29,9 +29,10 @@
from fairseq2.data.audio import AudioDecoder
from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching
from fairseq2.datasets.data_reader import DataPipelineReader, DataReader
from fairseq2.datasets.data_reader import DataPipelineReader, DataReader, SyncMode
from fairseq2.datasets.error import DatasetError
from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader
from fairseq2.datasets.static import load_dataset
from fairseq2.error import NotSupportedError
from fairseq2.gang import Gang
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.nn.padding import get_seqs_and_padding_mask
@@ -57,14 +58,14 @@ def create_reader(
batch_shuffle_window: int = 1,
drop_remainder: bool = False,
sync_batches: bool = True,
sync_mode: Literal["until_first", "until_last"] = "until_first",
sync_mode: SyncMode = "until_first",
max_num_batches: int | None = None,
num_accumulate: int = 1,
num_prefetch: int = 1,
seed: int = 2,
**extras: Any,
) -> DataReader[Seq2SeqBatch]:
"""Create a dataset reader.
"""Make a dataset reader.
:param split:
The split to read.
@@ -124,47 +125,51 @@ def splits(self) -> set[str]:
"""Return the set of splits."""


load_asr_dataset = DelegatingDatasetLoader[AsrDataset]()


# TODO: FIX, INFER
npc = 10


GENERIC_ASR_DATASET_FAMILY: Final = "generic_asr"


# TODO: Work in progress!
@final
class GenericAsrDataset(AsrDataset):
"""Represents a generic manifest-based ASR dataset."""

_name: str
_manifest_dir: Path
_splits: set[str]

def __init__(self, manifest_dir: Path, splits: set[str]) -> None:
def __init__(self, name: str, manifest_dir: Path, splits: set[str]) -> None:
"""
:param manifest_dir:
The directory under which the manifest files reside.
:param splits:
The available splits.
"""
self._name = name
self._manifest_dir = manifest_dir
self._splits = splits

@classmethod
def from_path(cls, path: Path) -> GenericAsrDataset:
"""Load a :class:`GenericAsrDataset` from ``path``."""
@staticmethod
def from_path(path: Path, name: str | None = None) -> GenericAsrDataset:
if name is None:
name = f"path:{path.name}"

path = path.expanduser().resolve()

if not path.is_dir():
return GenericAsrDataset(manifest_dir=path.parent, splits={path.stem})
return GenericAsrDataset(name, manifest_dir=path.parent, splits={path.stem})

try:
splits = {f.stem for f in path.glob("*.tsv")}
except OSError as ex:
raise RuntimeError(
"The splits cannot be determined. See nested exception for details."
raise DatasetError(
f"The splits under the '{path}' directory cannot be determined. See the nested exception for details."
) from ex

return GenericAsrDataset(path, splits)
return GenericAsrDataset(name, path, splits)

@override
def create_reader(
@@ -182,7 +187,7 @@ def create_reader(
batch_shuffle_window: int = 1,
drop_remainder: bool = False,
sync_batches: bool = True,
sync_mode: Literal["until_first", "until_last"] = "until_first",
sync_mode: SyncMode = "until_first",
max_num_batches: int | None = None,
num_accumulate: int = 1,
num_prefetch: int = 1,
@@ -196,8 +201,10 @@
audio files.
"""
if split not in self._splits:
s = ", ".join(sorted(self._splits))

raise ValueError(
f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits))}"
f"`split` must be one of the following splits, but is '{split}' instead: {s}"
)

audio_dir = self._retrieve_data_directory(split)
@@ -234,7 +241,7 @@ def create_reader(
)
elif isinstance(batching, StaticBatching):
# Filter out out-of-range audios.
def skip(example: dict[str, Any]) -> bool:
def skip(example: dict[str, object]) -> bool:
audio_len = cast(int, example["audio_size"])

return audio_len >= min_audio_len and audio_len <= max_audio_len
@@ -244,7 +251,7 @@ def skip(example: dict[str, Any]) -> bool:
# Bucket `batch_size` examples.
builder.bucket(batching.batch_size, drop_remainder=drop_remainder)
else:
raise RuntimeError(f"`{batching}` is not supported.")
raise NotSupportedError(f"`{batching}` is not supported.")

# Shuffle buckets.
if batch_shuffle_window != 1:
@@ -318,6 +325,7 @@ def to_batch(example: dict[str, Any]) -> Seq2SeqBatch:
pipeline = builder.map(to_batch).and_return()

return DataPipelineReader[Seq2SeqBatch](
self._name,
pipeline,
gang,
num_accumulate=num_accumulate,
@@ -334,14 +342,15 @@ def _retrieve_data_directory(self, split: str) -> Path:
line = fp.readline().rstrip()
except OSError as ex:
raise DatasetError(
f"{manifest_file} cannot be read. See nested exception for details."
self._name,
f"The {manifest_file} manifest file cannot be read. See the nested exception for details.",
) from ex

try:
return Path(line)
except ValueError:
raise DatasetError(
f"The first line of {manifest_file} must point to a data directory."
f"The first line of the '{manifest_file}' manifest file must point to a data directory."
) from None

def _read_manifest(self, split: str) -> DataPipelineBuilder:
@@ -381,18 +390,7 @@ def splits(self) -> set[str]:
return self._splits


@final
class GenericAsrDatasetLoader(AbstractDatasetLoader[GenericAsrDataset]):
@override
def _load(self, path: Path, card: AssetCard) -> GenericAsrDataset:
try:
return GenericAsrDataset.from_path(path)
except RuntimeError as ex:
raise AssetError(
f"{card.name} cannot be loaded. See nested exception for details."
) from ex


load_generic_asr_dataset = GenericAsrDatasetLoader()

load_asr_dataset.register("generic_asr", load_generic_asr_dataset)
def load_asr_dataset(
name_or_card: str | AssetCard, *, force: bool = False
) -> AsrDataset:
return load_dataset(name_or_card, AsrDataset, force=force)
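Reviewer note: `GenericAsrDataset.from_path` is now a plain static constructor with an optional `name`, so a manifest directory can be opened without going through the asset store. A sketch under the assumption that the directory holds one `<split>.tsv` manifest per split; the path is hypothetical:

```python
from pathlib import Path

from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset

# Open a manifest directory directly; every *.tsv file becomes a split.
# The path below is hypothetical.
dataset = GenericAsrDataset.from_path(Path("~/data/asr_manifests"))
print(dataset.splits())

# Or resolve a registered asset card through the typed wrapper:
# dataset = load_asr_dataset("my_asr_card")  # hypothetical card name
```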