diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 88ea999e1..b8863fd41 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,4 @@ * @cbalioglu src/fairseq2/data/parquet/ @artemru tests/integration/parquet/ @artemru +doc/ @zyaoj diff --git a/.github/workflows/_build_doc.yaml b/.github/workflows/_build_doc.yaml index 64cbc434f..d7491e730 100644 --- a/.github/workflows/_build_doc.yaml +++ b/.github/workflows/_build_doc.yaml @@ -9,31 +9,32 @@ on: inputs: torch: type: string - default: '2.3.0' + default: '2.5.1' py: type: string - default: '3.11' + default: '3.12' version_override: type: string default: '' -env: - ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true - jobs: build: name: Build runs-on: labels: 4-core-ubuntu container: - image: ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu + image: ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:3-cpu defaults: run: shell: bash steps: + - name: Install Git LFS + run: | + yum install -y git-lfs - name: Check-out the repository uses: actions/checkout@v3 with: + lfs: true submodules: recursive - name: Create the Python virtual environment run: | @@ -57,7 +58,7 @@ jobs: - name: Configure fairseq2n working-directory: native run: | - cmake -GNinja -B build + scl enable gcc-toolset-11 "cmake -GNinja -B build" - name: Install fairseq2n run: | pip install --editable native/python @@ -72,7 +73,7 @@ jobs: run: | cp VERSION doc/build/html - name: Upload documentation to staging - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: doc path: doc/build/html/ diff --git a/.github/workflows/_build_wheel-linux.yaml b/.github/workflows/_build_wheel-linux.yaml index 943453e60..18d7a12cb 100644 --- a/.github/workflows/_build_wheel-linux.yaml +++ b/.github/workflows/_build_wheel-linux.yaml @@ -38,16 +38,13 @@ on: type: boolean default: false -env: - ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true - jobs: build: name: Build runs-on: labels: 4-core-ubuntu container: - image: ghcr.io/facebookresearch/fairseq2-ci-manylinux_${{ inputs.arch }}:2-${{ inputs.variant }} + image: ghcr.io/facebookresearch/fairseq2-ci-manylinux_${{ inputs.arch }}:3-${{ inputs.variant }} defaults: run: shell: bash @@ -81,9 +78,9 @@ jobs: # If the version has already a local label, append the variant. if [[ $version == *+* ]]; then - tools/set-project-version.sh $version.$VARIANT + tools/set-project-version.sh --native-only $version.$VARIANT else - tools/set-project-version.sh $version+$VARIANT + tools/set-project-version.sh --native-only $version+$VARIANT fi - name: Build fairseq2n working-directory: native @@ -112,18 +109,20 @@ jobs: lto=OFF fi + scl enable gcc-toolset-11 - < [!WARNING] diff --git a/README.md b/README.md index e85cbc488..81e207a50 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

-
+

# fairseq2: FAIR Sequence Modeling Toolkit 2 @@ -16,7 +16,7 @@ other content generation tasks. It is also the successor of [fairseq](https://github.com/facebookresearch/fairseq). ## Getting Started -Coming soon... +Visit our [documentation website](https://facebookresearch.github.io/fairseq2/stable/). For recent changes, you can check out our [changelog](CHANGELOG.md). @@ -24,13 +24,11 @@ For recent changes, you can check out our [changelog](CHANGELOG.md). ## Models As of today, the following models are available in fairseq2: - * [LLaMA](src/fairseq2/models/llama) - * [LLaMA 2](src/fairseq2/models/llama) - * [LLaMA 3](src/fairseq2/models/llama) - * [LLaMA 3.1](src/fairseq2/models/llama) + * [LLaMA 1 to 3.3](src/fairseq2/models/llama) * [Mistral 7B](src/fairseq2/mistral) * [NLLB-200](src/fairseq2/models/nllb) * [S2T Transformer + Conformer](src/fairseq2/models/s2t_transformer) + * [V-JEPA](src/fairseq2/models/jepa) * [w2v-BERT](src/fairseq2/models/w2vbert) * [wav2vec 2.0](src/fairseq2/models/wav2vec2) * [wav2vec 2.0 ASR](src/fairseq2/models/wav2vec2/asr) @@ -38,6 +36,7 @@ As of today, the following models are available in fairseq2: fairseq2 is also used by various external projects such as: * [Seamless Communication](https://github.com/facebookresearch/seamless_communication) + * [Large Concept Model](https://github.com/facebookresearch/large_concept_model) * [SONAR](https://github.com/facebookresearch/SONAR) @@ -92,20 +91,40 @@ matrix shows the supported combinations. HEAD - 2.3.0 - >=3.8, <=3.11 - cpu, cu118, cu121 + 2.5.0, 2.5.1 + >=3.10, <=3.12 + cpu, cu118, cu121, cu124 x86_64 - 2.2.0, 2.2.1, 2.2.2 - >=3.8, <=3.11 + 2.4.0, 2.4.1 + >=3.10, <=3.12 + cpu, cu118, cu121, cu124 + x86_64 + + + 2.3.0, 2.3.1 + >=3.10, <=3.12 cpu, cu118, cu121 x86_64 + - 2.1.0, 2.1.1, 2.1.2 - >=3.8, <=3.11 + 0.3.0 + 2.5.0, 2.5.1 + >=3.10, <=3.12 + cpu, cu118, cu121, cu124 + x86_64 + + + 2.4.0, 2.4.1 + >=3.10, <=3.12 + cpu, cu118, cu121, cu124 + x86_64 + + + 2.3.0, 2.3.1 + >=3.10, <=3.12 cpu, cu118, cu121 x86_64 @@ -135,12 +154,12 @@ matrix shows the supported combinations. To install a specific combination, first follow the installation instructions on [pytorch.org](https://pytorch.org/get-started/locally) for the desired PyTorch -version, and then use the following command (shown for PyTorch `2.3.0` and -variant `cu118`): +version, and then use the following command (shown for PyTorch `2.5.1` and +variant `cu124`): ```sh pip install fairseq2\ - --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.3.0/cu118 + --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.5.1/cu124 ``` > [!WARNING] @@ -155,12 +174,12 @@ pip install fairseq2\ For Linux, we also host nightly builds on FAIR's package repository. The supported variants are identical to the ones listed in *Variants* above. Once you have installed the desired PyTorch version, you can use the following -command to install the corresponding nightly package (shown for PyTorch `2.3.0` -and variant `cu118`): +command to install the corresponding nightly package (shown for PyTorch `2.5.1` +and variant `cu124`): ```sh pip install fairseq2\ - --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.3.0/cu118 + --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.5.1/cu124 ``` @@ -202,9 +221,9 @@ the supported combinations. - HEAD - 2.3.0 - >=3.8, <=3.11 + 0.3.0 + 2.5.1 + >=3.10, <=3.12 arm64 @@ -212,11 +231,11 @@ the supported combinations. To install a specific combination, first follow the installation instructions on [pytorch.org](https://pytorch.org/get-started/locally) for the desired PyTorch -version, and then use the following command (shown for PyTorch `2.3.0`): +version, and then use the following command (shown for PyTorch `2.5.1`): ```sh pip install fairseq2\ - --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.3.0/cpu + --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.5.1/cpu ``` > [!WARNING] @@ -231,11 +250,11 @@ pip install fairseq2\ For macOS, we also host nightly builds on FAIR's package repository. The supported variants are identical to the ones listed in *Variants* above. Once you have installed the desired PyTorch version, you can use the following -command to install the corresponding nightly package (shown for PyTorch `2.3.0`): +command to install the corresponding nightly package (shown for PyTorch `2.5.1`): ```sh pip install fairseq2\ - --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.3.0/cpu + --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.5.1/cpu ``` diff --git a/VERSION b/VERSION index 13668bbb9..5acf154b8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.0.dev0 +0.4.0.dev0 diff --git a/bibliography.bib b/bibliography.bib index 2e275df7c..e8e0a8daf 100644 --- a/bibliography.bib +++ b/bibliography.bib @@ -1,10 +1,11 @@ -@misc{https://doi.org/10.48550/arxiv.1608.03983, - title={SGDR: Stochastic Gradient Descent with Warm Restarts}, - author={Ilya Loshchilov and Frank Hutter}, - year={2017}, - eprint={1608.03983}, +@misc{https://doi.org/10.48550/arxiv.1603.09382, + title={Deep Networks with Stochastic Depth}, + author={Gao Huang and Yu Sun and Zhuang Liu and Daniel Sedra and Kilian Weinberger}, + year={2016}, + eprint={1603.09382}, archivePrefix={arXiv}, - primaryClass={cs.LG} + primaryClass={cs.LG}, + url={https://arxiv.org/abs/1603.09382}, } @misc{https://doi.org/10.48550/arxiv.1607.06450, @@ -25,6 +26,25 @@ @misc{https://doi.org/10.48550/arxiv.1609.03499 primaryClass={cs.SD} } +@misc{https://doi.org/10.48550/arxiv.1608.03983, + title={SGDR: Stochastic Gradient Descent with Warm Restarts}, + author={Ilya Loshchilov and Frank Hutter}, + year={2017}, + eprint={1608.03983}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} + +@misc{https://doi.org/10.48550/arXiv.1612.08083, + title={Language Modeling with Gated Convolutional Networks}, + author={Yann N. Dauphin and Angela Fan and Michael Auli and David Grangier}, + year={2017}, + eprint={1612.08083}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/1612.08083}, +} + @misc{https://doi.org/10.48550/arxiv.1706.03762, doi={10.48550/ARXIV.1706.03762}, url={https://arxiv.org/abs/1706.03762}, @@ -193,15 +213,14 @@ @misc{https://doi.org/10.48550/arxiv.2108.12409 copyright={arXiv.org perpetual, non-exclusive license} } -@misc{https://doi.org/10.48550/arxiv.2207.04672, - doi={10.48550/arxiv.2207.04672}, - url={https://arxiv.org/abs/2207.04672}, - title={No Language Left Behind: Scaling Human-Centered Machine Translation}, - author={NLLB Team and Marta R. Costa-jussà and James Cross and Onur Çelebi and Maha Elbayad and Kenneth Heafield and Kevin Heffernan and Elahe Kalbassi and Janice Lam and Daniel Licht and Jean Maillard and Anna Sun and Skyler Wang and Guillaume Wenzek and Al Youngblood and Bapi Akula and Loic Barrault and Gabriel Mejia Gonzalez and Prangthip Hansanti and John Hoffman and Semarley Jarrett and Kaushik Ram Sadagopan and Dirk Rowe and Shannon Spruit and Chau Tran and Pierre Andrews and Necip Fazil Ayan and Shruti Bhosale and Sergey Edunov and Angela Fan and Cynthia Gao and Vedanuj Goswami and Francisco Guzmán and Philipp Koehn and Alexandre Mourachko and Christophe Ropers and Safiyyah Saleem and Holger Schwenk and Jeff Wang}, - year={2022}, - eprint={2207.04672}, - archivePrefix={arXiv}, - primaryClass={cs.CL} +@misc{https://doi.org/10.48550/arXiv.2010.11929, + title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, + author={Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby}, + year={2021}, + eprint={2010.11929}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2010.11929}, } @misc{https://doi.org/10.48550/arxiv.2110.09456, @@ -215,6 +234,17 @@ @misc{https://doi.org/10.48550/arxiv.2110.09456 copyright={arXiv.org perpetual, non-exclusive license} } +@misc{https://doi.org/10.48550/arxiv.2207.04672, + doi={10.48550/arxiv.2207.04672}, + url={https://arxiv.org/abs/2207.04672}, + title={No Language Left Behind: Scaling Human-Centered Machine Translation}, + author={NLLB Team and Marta R. Costa-jussà and James Cross and Onur Çelebi and Maha Elbayad and Kenneth Heafield and Kevin Heffernan and Elahe Kalbassi and Janice Lam and Daniel Licht and Jean Maillard and Anna Sun and Skyler Wang and Guillaume Wenzek and Al Youngblood and Bapi Akula and Loic Barrault and Gabriel Mejia Gonzalez and Prangthip Hansanti and John Hoffman and Semarley Jarrett and Kaushik Ram Sadagopan and Dirk Rowe and Shannon Spruit and Chau Tran and Pierre Andrews and Necip Fazil Ayan and Shruti Bhosale and Sergey Edunov and Angela Fan and Cynthia Gao and Vedanuj Goswami and Francisco Guzmán and Philipp Koehn and Alexandre Mourachko and Christophe Ropers and Safiyyah Saleem and Holger Schwenk and Jeff Wang}, + year={2022}, + eprint={2207.04672}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} + @misc{https://doi.org/10.48550/arxiv.2212.08055, title={UnitY: Two-pass Direct Speech-to-speech Translation with Discrete Units}, author={Hirofumi Inaguma and Sravya Popuri and Ilia Kulikov and Peng-Jen Chen and Changhan Wang and Yu-An Chung and Yun Tang and Ann Lee and Shinji Watanabe and Juan Pino}, @@ -259,3 +289,23 @@ @misc{https://doi.org/10.48550/arXiv.2310.06825 archivePrefix={arXiv}, primaryClass={cs.CL} } + +@misc{https://doi.org/10.48550/arXiv.2301.08243, + title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture}, + author={Mahmoud Assran and Quentin Duval and Ishan Misra and Piotr Bojanowski and Pascal Vincent and Michael Rabbat and Yann LeCun and Nicolas Ballas}, + year={2023}, + eprint={2301.08243}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2301.08243}, +} + +@misc{https://doi.org/10.48550/arXiv.2404.08471, + title={Revisiting Feature Prediction for Learning Visual Representations from Video}, + author={Adrien Bardes and Quentin Garrido and Jean Ponce and Xinlei Chen and Michael Rabbat and Yann LeCun and Mahmoud Assran and Nicolas Ballas}, + year={2024}, + eprint={2404.08471}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2404.08471}, +} diff --git a/ci/docker/build-manylinux-images.sh b/ci/docker/build-manylinux-images.sh index 45dd7493b..ce637271c 100755 --- a/ci/docker/build-manylinux-images.sh +++ b/ci/docker/build-manylinux-images.sh @@ -12,9 +12,9 @@ repo=ghcr.io/facebookresearch arch=x86_64 -version=2 +version=3 -declare -a variants=(cpu cu116 cu117 cu118 cu121) +declare -a variants=(cpu cu118 cu121 cu124) for variant in "${variants[@]}"; do docker build\ diff --git a/ci/docker/manylinux_x86_64/Dockerfile.cpu b/ci/docker/manylinux_x86_64/Dockerfile.cpu index b9985d7f5..73bed677f 100644 --- a/ci/docker/manylinux_x86_64/Dockerfile.cpu +++ b/ci/docker/manylinux_x86_64/Dockerfile.cpu @@ -4,12 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -FROM quay.io/pypa/manylinux2014_x86_64 +FROM quay.io/pypa/manylinux_2_28_x86_64 # Install system dependencies. -RUN yum --assumeyes install\ - devtoolset-10-lib{asan,lsan,ubsan,tsan}-devel libsndfile-devel &&\ - yum clean all +RUN dnf --assumeyes install\ + gcc-toolset-11\ + gcc-toolset-11-lib{asan,lsan,ubsan,tsan}-devel\ + libsndfile-devel &&\ + dnf clean all # Install Ninja. RUN pipx install --pip-args=--no-cache-dir ninja @@ -17,7 +19,8 @@ RUN pipx install --pip-args=--no-cache-dir ninja # Install LLVM. COPY build-scripts/install-llvm.sh /build-scripts/ -RUN /build-scripts/install-llvm.sh && rm -rf /build-scripts +RUN scl enable gcc-toolset-11 /build-scripts/install-llvm.sh &&\ + rm -rf /build-scripts # Path to sanitizer libs. Used by the CI tests. ENV LIBASAN=/usr/lib64/libasan.so.6 diff --git a/ci/docker/manylinux_x86_64/Dockerfile.cu117 b/ci/docker/manylinux_x86_64/Dockerfile.cu117 deleted file mode 100644 index 4b35f65d9..000000000 --- a/ci/docker/manylinux_x86_64/Dockerfile.cu117 +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu - -# Install CUDA. -COPY build-scripts/install-cuda-11.7.sh /build-scripts/ - -RUN /build-scripts/install-cuda-11.7.sh && rm -rf /build-scripts - -ENV PATH=/usr/local/cuda-11.7/bin:$PATH diff --git a/ci/docker/manylinux_x86_64/Dockerfile.cu118 b/ci/docker/manylinux_x86_64/Dockerfile.cu118 index 57f23c8ee..e7473a8db 100644 --- a/ci/docker/manylinux_x86_64/Dockerfile.cu118 +++ b/ci/docker/manylinux_x86_64/Dockerfile.cu118 @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu +FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:3-cpu # Install CUDA. COPY build-scripts/install-cuda-11.8.sh /build-scripts/ diff --git a/ci/docker/manylinux_x86_64/Dockerfile.cu121 b/ci/docker/manylinux_x86_64/Dockerfile.cu121 index 0037d2e61..7a86cbd14 100644 --- a/ci/docker/manylinux_x86_64/Dockerfile.cu121 +++ b/ci/docker/manylinux_x86_64/Dockerfile.cu121 @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu +FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:3-cpu # Install CUDA. COPY build-scripts/install-cuda-12.1.sh /build-scripts/ diff --git a/ci/docker/manylinux_x86_64/Dockerfile.cu116 b/ci/docker/manylinux_x86_64/Dockerfile.cu124 similarity index 50% rename from ci/docker/manylinux_x86_64/Dockerfile.cu116 rename to ci/docker/manylinux_x86_64/Dockerfile.cu124 index c5dd28048..4a56d038f 100644 --- a/ci/docker/manylinux_x86_64/Dockerfile.cu116 +++ b/ci/docker/manylinux_x86_64/Dockerfile.cu124 @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:2-cpu +FROM ghcr.io/facebookresearch/fairseq2-ci-manylinux_x86_64:3-cpu # Install CUDA. -COPY build-scripts/install-cuda-11.6.sh /build-scripts/ +COPY build-scripts/install-cuda-12.4.sh /build-scripts/ -RUN /build-scripts/install-cuda-11.6.sh && rm -rf /build-scripts +RUN /build-scripts/install-cuda-12.4.sh && rm -rf /build-scripts -ENV PATH=/usr/local/cuda-11.6/bin:$PATH +ENV PATH=/usr/local/cuda-12.4/bin:$PATH diff --git a/ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.7.sh b/ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.7.sh deleted file mode 100755 index 9593afe28..000000000 --- a/ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.7.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -set -eo pipefail - -curl --location --fail --output cuda.run\ - https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run - -sh cuda.run --silent --toolkit --override --no-man-page - -rm cuda.run - -# We don't need Nsight. -rm -rf /usr/local/cuda-11.7/nsight* - -# Add CUDA libraries to the lookup cache of the dynamic linker. -ldconfig diff --git a/ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.6.sh b/ci/docker/manylinux_x86_64/build-scripts/install-cuda-12.4.sh similarity index 75% rename from ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.6.sh rename to ci/docker/manylinux_x86_64/build-scripts/install-cuda-12.4.sh index 0e056e411..9e9af4b99 100755 --- a/ci/docker/manylinux_x86_64/build-scripts/install-cuda-11.6.sh +++ b/ci/docker/manylinux_x86_64/build-scripts/install-cuda-12.4.sh @@ -9,14 +9,14 @@ set -eo pipefail curl --location --fail --output cuda.run\ - https://developer.download.nvidia.com/compute/cuda/11.6.0/local_installers/cuda_11.6.0_510.39.01_linux.run + https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run sh cuda.run --silent --toolkit --override --no-man-page rm cuda.run # We don't need Nsight. -rm -rf /usr/local/cuda-11.6/nsight* +rm -rf /usr/local/cuda-12.4/nsight* # Add CUDA libraries to the lookup cache of the dynamic linker. ldconfig diff --git a/doc/.gitattributes b/doc/.gitattributes new file mode 100644 index 000000000..24a8e8793 --- /dev/null +++ b/doc/.gitattributes @@ -0,0 +1 @@ +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/doc/.gitignore b/doc/.gitignore index 1710e8418..907f27834 100644 --- a/doc/.gitignore +++ b/doc/.gitignore @@ -1,2 +1,3 @@ # Auto-Generated Sphinx Stub Files generated/ +build/ \ No newline at end of file diff --git a/doc/Makefile b/doc/Makefile index a9d264852..d0c3cbf10 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SOURCEDIR = +SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". diff --git a/doc/README.md b/doc/README.md new file mode 100644 index 000000000..3941c684d --- /dev/null +++ b/doc/README.md @@ -0,0 +1,23 @@ +# fairseq2 documents + +## Install dependencies + +Follow the installation instructions and install fairseq2 and fairseq2n. + +## Build the docs + +```bash +# Install dependencies. +pip install -r requirements.txt + +# Build the docs. +make clean +make html +``` + +## Open the docs with your browser + +```bash +python -m http.server -d build/html/ +``` +Launch your browser and open localhost:8000. diff --git a/doc/bibliography.rst b/doc/bibliography.rst deleted file mode 100644 index 0d21440d0..000000000 --- a/doc/bibliography.rst +++ /dev/null @@ -1,4 +0,0 @@ -Bibliography -============ - -.. bibliography:: diff --git a/doc/conf.py b/doc/conf.py deleted file mode 100644 index 0574ebd58..000000000 --- a/doc/conf.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import fairseq2n - -fairseq2n.DOC_MODE = True - -import fairseq2 - -# ------------------------------------------------------------ -# Project Information -# ------------------------------------------------------------ - -project = "fairseq2" - -version = fairseq2.__version__ - -release = fairseq2.__version__ - -author = "Fundamental AI Research (FAIR) at Meta" - -# ------------------------------------------------------------ -# General Configuration -# ------------------------------------------------------------ - -needs_sphinx = "5.0.0" - -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.autosectionlabel", - "sphinx.ext.autosummary", - "sphinx.ext.coverage", - "sphinx.ext.intersphinx", - "sphinx.ext.todo", - "sphinx.ext.viewcode", - "sphinxcontrib.bibtex", -] - -primary_domain = "py" - -highlight_language = "python3" - -autodoc_typehints = "description" -autodoc_typehints_format = "short" -autodoc_typehints_description_target = "documented_params" - -autosectionlabel_prefix_document = True - -todo_include_todos = True - -intersphinx_mapping = { - "python": ("https://docs.python.org/3/", None), - "torch": ("https://pytorch.org/docs/stable/", None), -} - -templates_path = ["templates"] - -bibtex_bibfiles = ["../bibliography.bib"] - -# ------------------------------------------------------------ -# HTML Output Options -# ------------------------------------------------------------ - -html_theme = "sphinx_rtd_theme" - -html_theme_options = { - "collapse_navigation": False, - "navigation_depth": 3, -} - -html_show_copyright = False - -html_static_path = ["static"] diff --git a/doc/index.rst b/doc/index.rst deleted file mode 100644 index d916ef9ae..000000000 --- a/doc/index.rst +++ /dev/null @@ -1,23 +0,0 @@ -:github_url: https://github.com/facebookresearch/fairseq2 - - -fairseq2 documentation -====================== - -fairseq2 is a sequence modeling toolkit that allows researchers and developers -to train custom models for translation, summarization, language modeling, and -other content generation tasks. - -.. toctree:: - :caption: fairseq2 Reference - :maxdepth: 1 - - reference/data - reference/asset - reference/all - -.. toctree:: - :maxdepth: 1 - :caption: Misc - - bibliography diff --git a/doc/make.bat b/doc/make.bat new file mode 100644 index 000000000..dc1312ab0 --- /dev/null +++ b/doc/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/doc/reference/abc.rst b/doc/reference/abc.rst deleted file mode 100644 index 1bca9203e..000000000 --- a/doc/reference/abc.rst +++ /dev/null @@ -1,11 +0,0 @@ -ABCs and Protocols -================== -.. body - -.. currentmodule:: fairseq2 - -.. autosummary:: - :toctree: generated/abc - :nosignatures: - - gang.Gang diff --git a/doc/reference/all.rst b/doc/reference/all.rst deleted file mode 100644 index c9b24f03f..000000000 --- a/doc/reference/all.rst +++ /dev/null @@ -1,24 +0,0 @@ -:tocdepth: 1 - -All -=== - -ABCs and Protocols ------------------- -.. include:: abc.rst - :start-after: .. body - -Classes -------- -.. include:: classes.rst - :start-after: .. body - -Enums ------ -.. include:: enums.rst - :start-after: .. body - -Functions ---------- -.. include:: functions.rst - :start-after: .. body diff --git a/doc/reference/asset.rst b/doc/reference/asset.rst deleted file mode 100644 index 2d185bbfe..000000000 --- a/doc/reference/asset.rst +++ /dev/null @@ -1,50 +0,0 @@ -fairseq2.assets -=============== -.. body - -.. currentmodule:: fairseq2.assets - -``fairseq2.asset`` provides API to load the different model using the "model cards" from different "stores". - - -.. autosummary:: - :toctree: generated/data - - AssetStore - AssetCard - AssetMetadataProvider - -Model store -~~~~~~~~~~~ - -A store is a place where all the model cards are stored. In fairseq2, a store is accessed via -`fairseq2.assets.AssetStore`. Multiple stores are allowed. By default, fairseq2 will look up the following stores: - -* System asset store: Cards that are shared by all users. By default, the system store is `/etc/fairseq2/assets`, - but this can be changed via the environment variable `FAIRSEQ2_ASSET_DIR` - -* User asset store: Cards that are only available to the user. By default, the user store is - `~/.config/fairseq2/assets`, but this can be changed via the environment variable `FAIRSEQ2_USER_ASSET_DIR` - -To register a new store, implement a :py:class:`fairseq2.assets.AssetMetadataProvider` and add them to -:py:class:`fairseq2.assets.asset_store`. Here is an example to register a new directory as a model store: - - from pathlib import Path - from fairseq2.assets import FileAssetMetadataProvider, asset_store - - my_dir = Path("/path/to/model_store") - asset_store.metadata_providers.append(FileAssetMetadataProvider(my_dir)) - - -Model card -~~~~~~~~~~~ - -A model card is a .YAML file that contains information about a model and instructs a -:py:class:`fairseq2.models.utils.generic_loaders.ModelLoader` on how to load the model into the memory. Each model card -must have 2 mandatory attributes: `name` and `checkpoint`. `name` will be used to identify the model card, and it must -be unique _across_ all -fairseq2 provides example cards for differen LLMs in -`fairseq2.assets.cards`. - -In fairseq2, a model card is accessed via :py:class:`fairseq2.assets.AssetCard`. Alternatively, one can call -`fairseq2.assets.AssetMetadataProvider.get_metadata(name: str)` to get the meta data of a given model card name. diff --git a/doc/reference/classes.rst b/doc/reference/classes.rst deleted file mode 100644 index f52a4d103..000000000 --- a/doc/reference/classes.rst +++ /dev/null @@ -1,14 +0,0 @@ -Classes -======= -.. body - -.. currentmodule:: fairseq2 - -.. autosummary:: - :toctree: generated/classes - :nosignatures: - - optim.lr_scheduler.CosineAnnealingLR - optim.lr_scheduler.MyleLR - optim.lr_scheduler.NoamLR - optim.lr_scheduler.PolynomialDecayLR diff --git a/doc/reference/data.rst b/doc/reference/data.rst deleted file mode 100644 index b9386e5fa..000000000 --- a/doc/reference/data.rst +++ /dev/null @@ -1,136 +0,0 @@ -fairseq2.data -============= -.. body - -.. currentmodule:: fairseq2.data - -``fairseq2.data`` provides a Python API to build a C++ :py:class:`DataPipeline`. - -The dataloader will be able to leverage several threads, -working around Python Global Interpreter Lock limitations, -and also providing better performance -than a pure Python dataloader. - -Building a :py:class:`DataPipeline` looks like this:: - - data = ( - text.read_text("file.tsv") - .map(lambda x: str(x.split("\t")[1]).lower()) - .filter(lambda x: len(x) < 10) - ) - -Functions to build a :py:class:`DataPipeline`: - - -.. autosummary:: - :toctree: generated/data - - DataPipeline - DataPipelineBuilder - - list_files - read_sequence - read_zipped_records - text.read_text - FileMapper - - Collater - CollateOptionsOverride - -Column syntax -~~~~~~~~~~~~~ - -The data items going through the pipeline don't have to be flat tensors, but can be tuples, or python dictionaries. -Several operators have a syntax to specify a specific column of the input data. -Notably the :py:func:`DataPipelineBuilder.map` operator -has a `selector` argument to choose the column to apply the function to. - -If the data item is a tuple, -then the selector ``"[3]"`` selects the third column. -If the data item is a dictionary, then ``"foo"`` will select the value corresponding to the key ``"foo"``. -You can nest selectors using ``.`` to separate key selectors, following a python-like syntax. -For a data item ``{"foo": [{"x": 1, "y": 2}, {"x": 3, "y": 4, "z": 5}], "bar": 6}``, -the selector ``"foo[1].y"`` referes to the value 4. - -Functions that accepts several selectors, -accept them as a comma separated list of selectors. -For example ``.map(lambda x: x * 10, selector="foo[1].y,bar")`` -will multiply the values 4 and 6 by 10, but leave others unmodified. - -Pseudo-infinite and Infinite Pipelines -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The :py:func:`DataPipeline.count` and :py:func:`DataPipeline.constant` static methods create pseudo-infinite pipelines. -When used with operators that combine multiple pipelines (e.g. :py:func:`DataPipeline.sample`, -:py:func:`DataPipeline.round_robin`, :py:func:`DataPipeline.zip`), -they will only yield examples as long as the other pipelines yield examples. - -For example:: - - from fairseq2.data import DataPipeline, read_sequence - - pipeline1 = DataPipeline.constant(0).and_return() - pipeline2 = read_sequence([1, 2, 3]).and_return() - - for example in DataPipeline.round_robin(pipeline1, pipeline2).and_return(): - print(example) - -only produces 0, 1, 0, 2, 0, 3. - -Infinite pipelines (pipelines created through :py:func:`DataPipelineBuilder.repeat` with no arguments) -do not exhibit this behavior; they will yield examples indefinitely even when combined with other pipelines. - -For example:: - - from fairseq2.data import DataPipeline, read_sequence - - pipeline1 = read_sequence([0]).repeat().and_return() - pipeline2 = read_sequence([1, 2, 3]).and_return() - - for example in DataPipeline.round_robin(pipeline1, pipeline2).and_return(): - print(example) - -produces 0, 1, 0, 2, 0, 3, 0, 1, 0, 2, 0, 3... indefinitely. - - -Public classes used in fairseq2 API: -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated/data - - ByteStreamError - DataPipelineError - RecordError - VocabularyInfo - -Helper methods: - -.. autosummary:: - :toctree: generated/data - - get_last_failed_example - -fairseq2.data.text -~~~~~~~~~~~~~~~~~~ - -Tools to tokenize text, converting it from bytes to tensors. - -.. currentmodule:: fairseq2.data.text - -.. autosummary:: - :toctree: generated/data_text - - TextTokenizer - TextTokenDecoder - TextTokenEncoder - - StrSplitter - StrToIntConverter - StrToTensorConverter - - SentencePieceModel - SentencePieceEncoder - SentencePieceDecoder - vocab_info_from_sentencepiece - LineEnding diff --git a/doc/reference/enums.rst b/doc/reference/enums.rst deleted file mode 100644 index fccc884ff..000000000 --- a/doc/reference/enums.rst +++ /dev/null @@ -1,11 +0,0 @@ -Enums -===== -.. body - -.. currentmodule:: fairseq2 - -.. autosummary:: - :toctree: generated/enums - :nosignatures: - - nn.transformer.TransformerNormOrder diff --git a/doc/reference/functions.rst b/doc/reference/functions.rst deleted file mode 100644 index a27bee021..000000000 --- a/doc/reference/functions.rst +++ /dev/null @@ -1,11 +0,0 @@ -Functions -========= -.. body - -.. currentmodule:: fairseq2 - -.. autosummary:: - :toctree: generated/functions - :nosignatures: - - nn.utils.mask.to_float_mask diff --git a/doc/requirements.txt b/doc/requirements.txt index 99f5a596b..31d5fdeef 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,3 +1,7 @@ -sphinx-rtd-theme~=1.2.2 -sphinx~=6.2.1 +sphinx~=7.4.0 sphinxcontrib-bibtex~=2.5.0 +sphinx-favicon~=1.0.1 +sphinx-design~=0.5.0 +myst-parser~=4.0.0 +sphinxcontrib-mermaid~=1.0.0 +furo==2024.8.6 \ No newline at end of file diff --git a/doc/source/_static/bibliography.bib b/doc/source/_static/bibliography.bib new file mode 100644 index 000000000..7a3bfa181 --- /dev/null +++ b/doc/source/_static/bibliography.bib @@ -0,0 +1,39 @@ +@misc{https://doi.org/10.48550/arxiv.1706.03762, + title={Attention Is All You Need}, + author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin}, + year={2023}, + eprint={1706.03762}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/1706.03762}, +} + +@misc{https://doi.org/10.48550/arxiv.2104.09864, + title={RoFormer: Enhanced Transformer with Rotary Position Embedding}, + author={Jianlin Su and Yu Lu and Shengfeng Pan and Ahmed Murtadha and Bo Wen and Yunfeng Liu}, + year={2023}, + eprint={2104.09864}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2104.09864}, +} + +@misc{https://doi.org/10.48550/arxiv.2302.13971, + title={LLaMA: Open and Efficient Foundation Language Models}, + author={Hugo Touvron and Thibaut Lavril and Gautier Izacard and Xavier Martinet and Marie-Anne Lachaux and Timothée Lacroix and Baptiste Rozière and Naman Goyal and Eric Hambro and Faisal Azhar and Aurelien Rodriguez and Armand Joulin and Edouard Grave and Guillaume Lample}, + year={2023}, + eprint={2302.13971}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2302.13971}, +} + +@misc{https://doi.org/10.48550/arxiv.2006.11477, + title={wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations}, + author={Alexei Baevski and Henry Zhou and Abdelrahman Mohamed and Michael Auli}, + year={2020}, + eprint={2006.11477}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2006.11477}, +} \ No newline at end of file diff --git a/doc/source/_static/img/data/complex_data_pipeline_example.mmd b/doc/source/_static/img/data/complex_data_pipeline_example.mmd new file mode 100644 index 000000000..b2bfa09a6 --- /dev/null +++ b/doc/source/_static/img/data/complex_data_pipeline_example.mmd @@ -0,0 +1,82 @@ +graph TD + %% Source Handling + subgraph SRC[Data Sources] + S3[S3] & BLOB[Blobstore] & HF[HuggingFace] & OTHER_SRC[Other Sources] + CONFIG[Dataset Config] + S3 & BLOB & HF & OTHER_SRC --> CONFIG + end + + %% Data Ingestion + subgraph INGEST[Data Ingestion] + READ[Parallel AsyncIO Reading] + LOAD{Load Strategy} + MEM[Memory Load] + STREAM[Stream Load] + + CONFIG --> READ + READ --> LOAD + LOAD -->|Full| MEM + LOAD -->|Stream| STREAM + end + + %% Preprocessing + subgraph PREP[Preprocessing] + SHUFFLE[Shuffle] + REPEAT[Repeat] + WEIGHT[Weight] + end + + MEM & STREAM --> PREP + + %% Sharding + subgraph SHARD[Sharding] + MULTI[Multi-GPU Setup] + EVEN[Even Sharding] + UNEVEN[Uneven Sharding] + + MULTI -->|Seed| EVEN + MULTI -->|Dynamic| UNEVEN + end + + PREP --> MULTI + + %% Data Processing + subgraph PROCESS[Data Processing] + direction TB + + subgraph FILTER[Filter] + FILTER_STD[Standard] + FILTER_CUSTOM[Custom] + end + + subgraph TRANS[Transform - map] + CLIP[Clip] & WRAP[Wrap] & AUG[Augment] & CLEAN[Clean] & OTHER[Other] + end + + subgraph BUCKET[Dynamic Bucketing] + TOK[By Token] & ROW[By Row] & COST_FN[By Cost Function] + end + + EVEN & UNEVEN --> FILTER + FILTER --> TRANS + TRANS --> BUCKET + end + + %% Output + subgraph OUT[Output] + YIELD[Dataset Yield] + FORMAT[Format Conversion] + + YIELD -->|Convert| FORMAT + end + + PROCESS --> OUT + + %% Styling + classDef primary fill:#eee,stroke:#333,stroke-width:2px + classDef secondary fill:#bbf,stroke:#333 + classDef action fill:#bfb,stroke:#333 + + class SRC,INGEST,PREP,SHARD,PROCESS,OUT primary + class CONFIG,READ,MULTI,TRANS,BUCKET,FORMAT secondary + class SHUFFLE,REPEAT,WEIGHT,CLIP,FILTER_STD,FILTER_CUSTOM,WRAP,AUG,CLEAN,OTHER action \ No newline at end of file diff --git a/doc/source/_static/img/data/complex_data_pipeline_example.svg b/doc/source/_static/img/data/complex_data_pipeline_example.svg new file mode 100644 index 000000000..8396281bc --- /dev/null +++ b/doc/source/_static/img/data/complex_data_pipeline_example.svg @@ -0,0 +1 @@ +

Data Processing

Sharding

Data Ingestion

Data Sources

Full

Stream

Seed

Dynamic

Output

Convert

Dataset Yield

Format Conversion

Dynamic Bucketing

By Token

By Row

By Cost Function

Transform - map

Clip

Wrap

Augment

Clean

Other

Filter

Standard

Custom

Preprocessing

Shuffle

Repeat

Weight

S3

Blobstore

HuggingFace

Other Sources

Dataset Config

Parallel AsyncIO Reading

Load Strategy

Memory Load

Stream Load

Multi-GPU Setup

Even Sharding

Uneven Sharding

\ No newline at end of file diff --git a/doc/source/_static/img/gang.svg b/doc/source/_static/img/gang.svg new file mode 100644 index 000000000..6277518d0 --- /dev/null +++ b/doc/source/_static/img/gang.svg @@ -0,0 +1,10 @@ + + + + + + + + GPU 0GPU 1GPU 2GPU 3TP 0TP 1GPU 4GPU 5GPU 6GPU 7TP 2TP 3DP 0DP 1Node 0Node 1Layer 0Layer 1Layer 2Layer 3Layer 4Layer 5ModelDataBatch 0Batch 1Batch 2Batch 3shard 0shard 1 \ No newline at end of file diff --git a/doc/source/_static/img/logo.png b/doc/source/_static/img/logo.png new file mode 100644 index 000000000..1830151c1 --- /dev/null +++ b/doc/source/_static/img/logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:850abbe321cffb628d19cb5c20585ae87659ad265ec8635c57fd91ffcd81d742 +size 73036 diff --git a/doc/source/_static/img/position_encoder.svg b/doc/source/_static/img/position_encoder.svg new file mode 100644 index 000000000..46425a4d5 --- /dev/null +++ b/doc/source/_static/img/position_encoder.svg @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + torch.nn.Module + + + + + + + SinusoidalPositionEncoder + + + + + + + RotaryEncoder + + + + + + + LearnedPositionEncoder + + + + + + + PositionEncoder + + + + + + + + + + + diff --git a/doc/source/_static/img/text_tokenizer.svg b/doc/source/_static/img/text_tokenizer.svg new file mode 100644 index 000000000..158f8fdcd --- /dev/null +++ b/doc/source/_static/img/text_tokenizer.svg @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + AbstractTextTokenizer + + + + + + + SentencePieceTokenizer + + + + + + + TiktokenTokenizer + + + + + + + TextTokenizer + + + + + + + + + diff --git a/doc/source/_static/img/tutorials/benchmark/2node_elapsed_time_relative.png b/doc/source/_static/img/tutorials/benchmark/2node_elapsed_time_relative.png new file mode 100644 index 000000000..84ad30436 --- /dev/null +++ b/doc/source/_static/img/tutorials/benchmark/2node_elapsed_time_relative.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d195a75592419227fcb06e1a559f0faa26a8574036d7f9391e1e7964604e2839 +size 557581 diff --git a/doc/source/_static/img/tutorials/benchmark/2node_eps_absolute.png b/doc/source/_static/img/tutorials/benchmark/2node_eps_absolute.png new file mode 100644 index 000000000..256e0ff4c --- /dev/null +++ b/doc/source/_static/img/tutorials/benchmark/2node_eps_absolute.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ed25afafabccc3d1fecf4c9132fd9e64a21d848df0911112a2f3bf516d97b53 +size 471716 diff --git a/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_accuracy.png b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_accuracy.png new file mode 100644 index 000000000..65038a03d --- /dev/null +++ b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_accuracy.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09e6d1ccdcb9fb2148d13f0e52552e87ea704aae62d1679f9ca8c342b6fff62b +size 132866 diff --git a/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_elements_per_second.png b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_elements_per_second.png new file mode 100644 index 000000000..046b91575 --- /dev/null +++ b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_elements_per_second.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd5b890162d7e17bba62b73a1b4b2777882dcd28f253851b535116ade483640f +size 130125 diff --git a/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_trace.png b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_trace.png new file mode 100644 index 000000000..4c5d98b22 --- /dev/null +++ b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_trace.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe8b1f8dd4423b35e78ec7254d7a8988380b956ab6403c4c9ae9d510ca8a7c03 +size 1105940 diff --git a/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_wandb.png b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_wandb.png new file mode 100644 index 000000000..ea3410e65 --- /dev/null +++ b/doc/source/_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_wandb.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d6a92ff1bde7808c0068b18eeacdec9ae7ff8795054b5838181cd7295e39509 +size 855736 diff --git a/doc/source/_static/img/tutorials/presets/tutorial_presets_benchmark.png b/doc/source/_static/img/tutorials/presets/tutorial_presets_benchmark.png new file mode 100644 index 000000000..9f67266d6 --- /dev/null +++ b/doc/source/_static/img/tutorials/presets/tutorial_presets_benchmark.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e512c6365eebbd9fb5619627382b36cc551e9cf3b2de2508abf599d6c420cccb +size 1657461 diff --git a/doc/source/_static/img/tutorials/pudb.png b/doc/source/_static/img/tutorials/pudb.png new file mode 100644 index 000000000..650ae6b20 --- /dev/null +++ b/doc/source/_static/img/tutorials/pudb.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:162763752f5977b570b97fc0258e3c87a9d37527e0862a5522852f495c3b5bdc +size 186828 diff --git a/doc/templates/footer.html b/doc/source/_templates/footer.html similarity index 100% rename from doc/templates/footer.html rename to doc/source/_templates/footer.html diff --git a/doc/source/basics/assets.rst b/doc/source/basics/assets.rst new file mode 100644 index 000000000..8e37ed83d --- /dev/null +++ b/doc/source/basics/assets.rst @@ -0,0 +1,108 @@ +.. _basics-assets: + +=========================== +:octicon:`container` Assets +=========================== + +.. currentmodule:: fairseq2.assets + +In fairseq2, "assets" refer to the various components that make up a sequence or language modeling task, such as datasets, models, tokenizers, etc. These assets are essential for training, evaluating, and deploying models. +``fairseq2.assets`` provides API to load the different models using the "model cards" from different "stores". + +Cards: YAML Files in fairseq2 +----------------------------- + +To organize these assets, fairseq2 uses a concept called "cards," which are essentially YAML files that describe the assets and their relationships. +For example, you can find all the "cards" in fairseq2 `here `__. +Cards provide a flexible way to define and manage the various components of an NLP task, making it easier to reuse, share, and combine different assets. + +How Cards Help Organize Assets +------------------------------ + +* **Asset Definition**: Cards define the assets used in an NLP task, including datasets, models, tokenizers, and other resources. + +* **Relationship Management**: Cards specify the relationships between assets, such as which dataset is used with which model or tokenizer. + +* **Reusability**: Cards enable reusability of assets across different tasks and projects, reducing duplication and increasing efficiency. + +* **Sharing and Collaboration**: Cards facilitate sharing and collaboration by providing a standardized way to describe and exchange assets. + + +How to Customize Your Assets +---------------------------- + +* How to add a dataset + + * Make sure that you have the dataset in place + + * Add the ``name``, ``dataset_family``, and ``data`` fields, which allows fairseq2 to find the corresponding dataset loader + + * For more detailed information about ``dataset_family``, please refer to :doc:`Dataset Loaders ` + +.. code-block:: yaml + + name: gsm8k_sft + dataset_family: generic_instruction + + --- + + name: gsm8k_sft@awscluster + data: "/data/gsm8k_data/sft" + + +* How to add a model + + * Make sure that you have the model checkpoint + + * Add the ``name`` and ``checkpoint`` fields + +.. code-block:: yaml + + name: llama3_2_1b@awscluster + checkpoint: "/models/Llama-3.2-1B/original/consolidated.00.pth" + + +Advanced Topics +--------------- + +Model Store +~~~~~~~~~~~ + +A store is a place where all the model cards are stored. In fairseq2, a store is accessed via +:py:class:`fairseq2.assets.AssetStore`. Multiple stores are allowed. By default, fairseq2 will look up the following stores: + +* System asset store: Cards that are shared by all users. By default, the system store is `/etc/fairseq2/assets`, + but this can be changed via the environment variable `FAIRSEQ2_ASSET_DIR` + +* User asset store: Cards that are only available to the user. By default, the user store is + `~/.config/fairseq2/assets`, but this can be changed via the environment variable `FAIRSEQ2_USER_ASSET_DIR` + +To register a new store, implement a :py:class:`fairseq2.assets.AssetMetadataProvider` and add them to +:py:class:`fairseq2.assets.asset_store`. Here is an example to register a new directory as a model store: + +.. code-block:: python + + from pathlib import Path + from fairseq2.assets import FileAssetMetadataProvider, asset_store + + my_dir = Path("/path/to/model_store") + asset_store.metadata_providers.append(FileAssetMetadataProvider(my_dir)) + + +Model Card +~~~~~~~~~~ + +A model card is a .YAML file that contains information about a model and instructs a +:py:class:`fairseq2.models.utils.generic_loaders.ModelLoader` on how to load the model into the memory. Each model card +must have 2 mandatory attributes: `name` and `checkpoint`. `name` will be used to identify the model card, and it must +be unique `across` all +fairseq2 provides example cards for different LLMs in +:py:mod:`fairseq2.assets.cards`. + +In fairseq2, a model card is accessed via :py:class:`fairseq2.assets.AssetCard`. Alternatively, one can call +`fairseq2.assets.AssetMetadataProvider.get_metadata(name: str)` to get the meta data of a given model card name. + +See Also +-------- + +- :doc:`Datasets ` diff --git a/doc/source/basics/ckpt.rst b/doc/source/basics/ckpt.rst new file mode 100644 index 000000000..c59fafe61 --- /dev/null +++ b/doc/source/basics/ckpt.rst @@ -0,0 +1,173 @@ +.. _basics-ckpt-management: + +:octicon:`check-circle` Checkpoint Management +============================================= + +The checkpoint manager in fairseq2 handles saving and loading of model states, optimizer states, and training progress. +It provides a robust way to: + +- Save model checkpoints during training +- Load checkpoints to resume training +- Manage multiple checkpoints with policies like keeping N-best or last N checkpoints +- Handle distributed training scenarios including FSDP (Fully Sharded Data Parallel) + +Architecture Overview +--------------------- + +.. mermaid:: + + graph TD + A[Trainer] -->|uses| B[CheckpointManager] + B -->|saves| C[Model State] + B -->|saves| D[Optimizer State] + B -->|saves| E[Training Metadata] + B -->|manages| F[Checkpoint Files] + G[Model Loader] -->|loads| B + +Basic Usage +----------- + +Saving Checkpoints +^^^^^^^^^^^^^^^^^^ + +The :class:`fairseq2.checkpoint.manager.CheckpointManager` provides a transactional API for saving checkpoints: + +.. code-block:: python + + # Initialize checkpoint manager + ckpt_manager = FileCheckpointManager( + checkpoint_dir=Path("checkpoints"), + gang=root_gang # For distributed training coordination + ) + + # Begin checkpoint operation + ckpt_manager.begin_checkpoint(step_nr=1000) + + # Save model and optimizer state + ckpt_manager.save_state({ + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "step_nr": 1000, + "epoch": 5 + }) + + # Save validation score if needed + ckpt_manager.save_score(valid_score) + + # Commit the checkpoint + ckpt_manager.commit_checkpoint() + +Loading Checkpoints +^^^^^^^^^^^^^^^^^^^ + +To load the latest checkpoint: + +.. code-block:: python + + try: + # Load the last checkpoint + step_nr, state = ckpt_manager.load_last_checkpoint() + + # Restore model and optimizer state + model.load_state_dict(state["model"]) + optimizer.load_state_dict(state["optimizer"]) + + print(f"Restored checkpoint from step {step_nr}") + except CheckpointNotFoundError: + print("No checkpoint found, starting fresh") + +Checkpoint Management Policies +------------------------------ + +The :class:`fairseq2.checkpoint.manager.CheckpointManager` supports different policies for managing multiple checkpoints: + +Keep Last N Checkpoints +^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Keep only the last 5 checkpoints + ckpt_manager.keep_last_n_checkpoints(n=5) + +Keep Best N Checkpoints +^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Keep the 3 checkpoints with best validation scores + ckpt_manager.keep_best_n_checkpoints( + n=3, + lower_better=True # True if lower scores are better + ) + +Distributed Training Support +---------------------------- + +The :class:`fairseq2.checkpoint.manager.CheckpointManager` handles distributed training scenarios including: + +- Data Parallel (DP) training +- Fully Sharded Data Parallel (FSDP) training +- Tensor Parallel (TP) training + +For FSDP, the manager provides special handling: + +.. code-block:: python + + # Save consolidated (non-sharded) model state + ckpt_manager.save_consolidated_fsdp_model(model) + +Checkpoint Structure +-------------------- + +A checkpoint directory contains: + +.. code-block:: text + + checkpoint_dir/ + ├── model.yaml # Model metadata + └── step_1000/ # Checkpoint at step 1000 + └── model.pt # Model training state + +For sharded checkpoints (FSDP), each rank has its own files: + +.. code-block:: text + + checkpoint_dir/ + ├── model.yaml # Model metadata + └── step_1000/ + ├── model.pt # Consolidated model + ├── rank_0.pt # Model rank 0 state + └── rank_1.pt # Model rank 1 state + +Error Handling +-------------- + +The checkpoint system provides specific exceptions for error cases: + +- ``CheckpointError``: Base class for checkpoint-related errors +- ``CheckpointNotFoundError``: Raised when attempting to load non-existent checkpoint +- ``InvalidOperationError``: Raised for invalid checkpoint operations + +Example error handling: + +.. code-block:: python + + try: + ckpt_manager.load_checkpoint(step_nr=1000) + except CheckpointNotFoundError: + print("Checkpoint not found") + except CheckpointError as e: + print(f"Error loading checkpoint: {e}") + +Best Practices +-------------- + +1. Always use the transactional API (``begin_checkpoint``/``commit_checkpoint``) to ensure checkpoint consistency + +2. Implement checkpoint cleanup policies to manage storage space + +3. Include sufficient metadata in checkpoints for reproducibility + +4. Handle checkpoint errors gracefully in production code + +5. For distributed training, ensure proper gang coordination diff --git a/doc/source/basics/cli.rst b/doc/source/basics/cli.rst new file mode 100644 index 000000000..95204b3e4 --- /dev/null +++ b/doc/source/basics/cli.rst @@ -0,0 +1,128 @@ +.. _basics-cli: + +:octicon:`terminal` CLI +======================= + +The Command-Line Interface (CLI) is a crucial feature in fairseq2, offering users a powerful and flexible way to interact with the framework. +With the CLI, you can quickly and easily execute tasks, customize recipes and configurations, and perform complex operations such as sweep runs and benchmarking. + +Basic Usage +----------- + +Here are some basic examples of using the CLI: + +.. code-block:: bash + + # Get help about available commands + fairseq2 -h + + # Get help about a specific command group (e.g. recipe lm) + fairseq2 lm -h + + # Get help about a specific command (e.g. recipe lm::instruction_finetune) + fairseq2 lm instruction_finetune -h + + # List available presets for a recipe (e.g. recipe lm::instruction_finetune) + fairseq2 lm instruction_finetune --list-presets + + # Dump the default configuration for a recipe (e.g. recipe lm::instruction_finetune) + fairseq2 lm instruction_finetune --dump-config + + # Run a recipe with default settings (e.g. recipe lm::instruction_finetune) + fairseq2 lm instruction_finetune + + # Run a recipe with a custom config file (e.g. recipe lm::instruction_finetune) + fairseq2 lm instruction_finetune --config-file .yaml + +Configuration Customization +--------------------------- + +fairseq2 provides multiple ways to customize recipe configurations: + +1. Using Config Files +^^^^^^^^^^^^^^^^^^^^^ + +You can specify one or multiple YAML config files: + +.. code-block:: bash + + # Single config file + fairseq2 lm instruction_finetune --config-file config1.yaml + + # Multiple config files (merged from left to right) + fairseq2 lm instruction_finetune --config-file base.yaml --config-file override.yaml + +2. Command Line Overrides +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Use ``--config`` to override specific values: + +.. code-block:: bash + + # Override single value + fairseq2 lm instruction_finetune --config max_num_tokens=512 + + # Override nested values + fairseq2 lm instruction_finetune --config optimizer_config.lr=4e-5 + + # Override multiple values + fairseq2 lm instruction_finetune --config max_num_tokens=512 max_seq_len=512 + + # Override a tuple + fairseq2 lm instruction_finetune --config profile="[500,10]" + +.. note:: + + Unlike ``--config-file``, only one ``--config`` argument can be used. + +3. Adding and Removing Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Use ``add:`` and ``del:`` directives for more advanced configuration: + +.. code-block:: bash + + # Add a new configuration value + fairseq2 lm instruction_finetune --config add:new_param=value + + # Remove a configuration value + fairseq2 lm instruction_finetune --config del:unwanted_param + +4. Combining Different Methods +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can combine all these methods, with later values taking precedence: + +.. code-block:: bash + + fairseq2 lm instruction_finetune \ + --config-file base.yaml \ + --config-file override.yaml \ + --config max_num_tokens=512 \ + optimizer_config.lr=4e-5 \ + add:custom_param=value + +Asset Management +---------------- + +fairseq2 provides commands to manage and inspect assets: + +.. code-block:: bash + + # List all available assets + fairseq2 assets list + + # Show details of a specific asset + fairseq2 assets show llama3_1_8b_instruct + + # List assets filtered by type + fairseq2 assets list --type model + fairseq2 assets list --type dataset + fairseq2 assets list --type tokenizer + +See More +-------- + +For more technical details about implementing custom CLIs and extensions, see: + +- :doc:`/reference/api/fairseq2.recipes/cli` diff --git a/doc/source/basics/data_pipeline.rst b/doc/source/basics/data_pipeline.rst new file mode 100644 index 000000000..38948a91a --- /dev/null +++ b/doc/source/basics/data_pipeline.rst @@ -0,0 +1,249 @@ +.. _basics-data-pipeline: + +:octicon:`database` Data Pipeline +================================= + +Data pipelines in fairseq2 provide an efficient way to process and transform data for machine learning tasks. +The implementation leverages multiple threads to work around Python's Global Interpreter Lock (GIL) limitations, +resulting in better performance than pure Python dataloaders. + +Basic Pipeline Structure +^^^^^^^^^^^^^^^^^^^^^^^^ + +A data pipeline consists of a series of operations that transform data. Here's a basic example:: + + data = ( + text.read_text("file.tsv") + .map(lambda x: str(x.split("\t")[1]).lower()) + .filter(lambda x: len(x) < 10) + ) + +.. mermaid:: + + graph LR + A[read_text] --> B[map] + B --> C[filter] + style A fill:#f9f,stroke:#333 + style B fill:#bbf,stroke:#333 + style C fill:#bfb,stroke:#333 + +.. dropdown:: A more complex pipeline that can be built w/ fairseq2 as a diagram + :icon: package + :animate: fade-in + + .. image:: /_static/img/data/complex_data_pipeline_example.svg + :alt: A more complex pipeline that can be built w/ fairseq2 + +.. _basics/data-pipeline/column-selection: + +Column Selection +^^^^^^^^^^^^^^^^ + +Data items in the pipeline can be tuples or Python dictionaries. Many operators support a `selector` argument to specify which column to process: + +- For tuples: ``"[3]"`` selects the fourth element (0-based indexing) +- For dictionaries: ``"foo"`` selects the value for key ``"foo"`` +- Nested selectors: Use ``.`` to traverse nested structures (e.g., ``"foo[1].y"``) + +Example with nested data:: + + data = {"foo": [{"x": 1, "y": 2}, {"x": 3, "y": 4, "z": 5}], "bar": 6} + # "foo[1].y" selects 4 + # "bar" selects 6 + +.. mermaid:: + + graph TD + A[Input Dictionary] --> B[foo] + A --> C[bar: 6] + B --> D[List Index 0] + B --> E[List Index 1] + D --> F[x: 1] + D --> G[y: 2] + E --> H[x: 3] + E --> I[y: 4] + E --> J[z: 5] + style I fill:#f96,stroke:#333 + style C fill:#f96,stroke:#333 + +.. _basics/data-pipeline/pipeline-types: + +Pipeline Types +^^^^^^^^^^^^^^ + +fairseq2 supports three types of pipelines: + +1. **Finite Pipelines**: Standard pipelines that terminate after processing all data +2. **Pseudo-infinite Pipelines**: Created using ``DataPipeline.count`` or ``DataPipeline.constant`` +3. **Infinite Pipelines**: Created using ``DataPipelineBuilder.repeat`` without arguments + +.. mermaid:: + + graph TD + subgraph Finite + A[read_sequence] --> B[End] + end + subgraph Pseudo-infinite + C[constant/count] --> D[Stops with other pipelines] + end + subgraph Infinite + E[repeat] --> F[Never ends] + end + +.. _basics/data-pipeline/combining-pipelines: + +Combining Pipelines +^^^^^^^^^^^^^^^^^^^ + +fairseq2 provides several ways to combine pipelines: + +1. **Round Robin**: Alternates between pipelines:: + + pipeline1 = DataPipeline.constant(0).and_return() + pipeline2 = read_sequence([1, 2, 3]).and_return() + + for example in DataPipeline.round_robin(pipeline1, pipeline2).and_return(): + print(example) + + # round_robin yields: 0, 1, 0, 2, 0, 3 + +2. **Zip**: Combines examples from multiple pipelines:: + + pipeline1 = read_sequence([0]).repeat().and_return() + pipeline2 = read_sequence([1, 2, 3]).and_return() + + for example in DataPipeline.zip(pipeline1, pipeline2, names=["a", "b"]).and_return(): + print(example) + + # Yields: {"a": 0, "b": 1}, {"a": 0, "b": 2}, {"a": 0, "b": 3} + +3. **Sample**: Randomly samples from pipelines based on weights:: + + pipeline1 = read_sequence([0]).repeat().and_return() + pipeline2 = read_sequence([1, 2, 3]).and_return() + + for example in DataPipeline.sample(pipeline1, pipeline2, weights=[0.5, 0.5]).and_return(): + print(example) + +.. mermaid:: + + graph TD + subgraph Round Robin + A1[Pipeline 1] --> C1{Alternate} + B1[Pipeline 2] --> C1 + C1 --> D1[Output] + end + subgraph Zip + A2[Pipeline 1] --> C2((Combine)) + B2[Pipeline 2] --> C2 + C2 --> D2[Output] + end + subgraph Sample + A3[Pipeline 1] --> C3{Random Select} + B3[Pipeline 2] --> C3 + C3 --> D3[Output] + end + +More Features +^^^^^^^^^^^^^ + +Shuffling +~~~~~~~~~ + +fairseq2 provides flexible shuffling capabilities through the ``shuffle`` operator: + +.. code-block:: python + + # Basic shuffling with a window size + pipeline = ( + read_sequence(data) + .shuffle(shuffle_window=1000) # Shuffle using a 1000-example buffer + .and_return() + ) + + # Shuffle between epochs + for epoch in range(3): + pipeline.reset() # By default, this re-shuffles data + for item in pipeline: + process(item) + + # Disable shuffling between epochs + pipeline.reset(reset_rng=True) # Keep the same order + +The shuffle operator maintains a buffer of the specified size. +When requesting the next example, it randomly samples from this buffer and replaces the selected example with a new one from the source. +Setting ``shuffle_window=0`` loads all examples into memory for full shuffling. + +Bucketing +~~~~~~~~~ + +Bucketing helps handle variable-length sequences efficiently. There are several bucketing strategies: + +1. **Fixed-size Bucketing**: Combine a fixed number of examples + +.. code-block:: python + + pipeline = ( + read_sequence(data) + .bucket(bucket_size=32, drop_remainder=True) # Combine 32 examples into one bucket + .and_return() + ) + +2. **Length-based Bucketing**: Group sequences of similar lengths + +.. code-block:: python + + from fairseq2.data import create_bucket_sizes + + # Create optimal bucket sizes + bucket_sizes = create_bucket_sizes( + max_num_elements=1024, # Max elements per bucket + max_seq_len=128, # Max sequence length + min_seq_len=1, # Min sequence length + num_seqs_multiple_of=8 # Ensure bucket sizes are multiples of 8 + ) + + # Use bucketing in pipeline + pipeline = ( + read_sequence(data) + .bucket_by_length( + bucket_sizes, + selector="length", # Column containing sequence lengths + skip_above_max_examples=True, # Skip sequences longer than max_seq_len + drop_remainder=False # Keep partial buckets + ) + .and_return() + ) + +3. **Dynamic Bucketing**: Combine examples based on a cost function + +.. code-block:: python + + def sequence_cost(example): + return len(example["text"]) + + pipeline = ( + read_sequence(data) + .dynamic_bucket( + threshold=1024, # Target bucket size + cost_fn=sequence_cost, # Function to compute example cost + min_num_examples=16, # Min examples per bucket + max_num_examples=64, # Max examples per bucket + drop_remainder=False # Keep partial buckets + ) + .and_return() + ) + + +This approach efficiently handles variable-length sequences while maintaining appropriate batch sizes for training. + +There are more features in fairseq2's data pipeline: + +- **Prefetching**: Load data ahead of time for better performance +- **State Management**: Save and restore pipeline state for resumable processing + +.. note:: + When combining pseudo-infinite pipelines with finite ones, the pseudo-infinite pipeline will stop when the finite pipeline ends. + For truly infinite behavior, use ``repeat()`` without arguments. + +For more technical details, see :doc:`/reference/api/fairseq2.data/data_pipeline`. \ No newline at end of file diff --git a/doc/source/basics/design_philosophy.rst b/doc/source/basics/design_philosophy.rst new file mode 100644 index 000000000..57fdeb766 --- /dev/null +++ b/doc/source/basics/design_philosophy.rst @@ -0,0 +1,75 @@ +.. _basics-design-philosophy: + +===================================== +:octicon:`infinity` Design Philosophy +===================================== + +One of the core goals of fairseq2 is to make it possible for researchers to +explore new ideas and implement novel features without having to fork fairseq2. +Instead of having a monolithic repository that can only be modified by +copy-pasting large chunks of code, in fairseq2, all major APIs follow the +interface/implementation convention along with the `dependency inversion principle`__. +This means, each API has an *interface* (i.e. an abstract :class:`~abc.ABC` +class) that defines the contract of that API, and one or more concrete +implementations of that interface. Different implementations can be integrated +with the rest of fairseq2 via its lightweight `dependency injection API`__. + +.. __: https://en.wikipedia.org/wiki/Dependency_inversion_principle +.. __: https://en.wikipedia.org/wiki/Dependency_injection + +Interface/Implementation Convention +=================================== + +.. currentmodule:: fairseq2.nn + +The diagram below shows the :doc:`position encoder API ` +as an example. The API is defined by the abstract :class:`PositionEncoder` +PyTorch module. :class:`SinusoidalPositionEncoder`, :class:`LearnedPositionEncoder`, +and :class:`RotaryEncoder` implement :class:`PositionEncoder` for their +respective algorithms. Technically, any of these position encoders can be used +wherever a :class:`PositionEncoder` is expected (see `Dependency Inversion`_ +below). + +.. image:: /_static/img/position_encoder.svg + :width: 580px + :align: center + :alt: Position Encoder Hierarchy + +.. currentmodule:: fairseq2.data.text + +When several implementations of an API share common logic, a typical pattern is +to have an intermediate abstract class, prefixed with ``Abstract``, between the +interface and the concrete implementations. For example, the :doc:`text tokenizer +API ` has :class:`AbstractTextTokenizer` +that holds the common logic for :class:`SentencePieceTokenizer` and +:class:`TiktokenTokenizer`. + +.. image:: /_static/img/text_tokenizer.svg + :width: 580px + :align: center + :alt: Text Tokenizer Hierarchy + +Dependency Inversion +==================== + +.. currentmodule:: fairseq2.nn.transformer + +The dependency inversion principle is critical to have a clean, well-tested, and +extensible API. The example below shows the (abbreviated) ``__init__()`` method +of the :class:`StandardTransformerDecoderLayer`:: + + class StandardTransformerDecoderLayer(TransformerDecoderLayer): + def __init__( + self, + self_attn: MultiheadAttention, + encoder_decoder_attn: MultiheadAttention | None, + ffn: FeedForwardNetwork + ) -> None: + ... + +Instead of constructing the multihead attention and feed-forward network layers +within its ``__init__()`` method, :class:`StandardTransformerDecoderLayer` +expects the caller to provide instances of :class:`MultiheadAttention` and +:class:`FeedForwardNetwork` interfaces. This loose-coupling between an instance +and its dependencies enables composing diverse object graphs, such as different +model architectures, with minimal redundancy (i.e. code duplication). diff --git a/doc/source/basics/gang.rst b/doc/source/basics/gang.rst new file mode 100644 index 000000000..5c5482918 --- /dev/null +++ b/doc/source/basics/gang.rst @@ -0,0 +1,312 @@ +.. _basics-gang: + +:octicon:`table` Gang +===================== + + +Overview +-------- + +Gang is fairseq2's abstraction for distributed training that provides a clean interface for collective operations (`e.g.`, ``all_reduce``, ``all_gather``, and ``broadcast``) across processes in a distributed environment. +It simplifies PyTorch's distributed training while supporting both data parallelism and tensor parallelism. + +This design encapsulates the complexity of PyTorch's ``torch.distributed`` while supporting: + +- **Data Parallelism**: Distributing batches of data across multiple GPUs. +- **Tensor Parallelism**: Partitioning model tensors for efficient computation. +- **Flexible Process Grouping**: Organizing processes into groups dynamically. + +Core Concepts +------------- + +.. note:: + + It would be helpful to understand the following concepts before diving into Gang: + + - `PyTorch Distributed `_ + - `FSDP `_ + - `Distributed Device Mesh `_ + +What's Gang? +^^^^^^^^^^^^ + +A Gang represents a group of processes (`e.g.`, GPUs) that work together in a distributed setting. +Each Gang: + +- Has a unique rank for each process +- Knows its total size (number of processes) +- Supports collective operations (`e.g.`, ``all_reduce``, ``broadcast``) +- Is associated with a specific device (CPU or CUDA) + +By abstracting the concept of "process groups" from PyTorch Distributed, Gangs make distributed training simpler and more expressive. + +Types of Gangs +^^^^^^^^^^^^^^ + +1. **FakeGang** + + - A non-distributed gang for single-process execution + - Useful for local development and debugging + - Emulates distributed operations locally + +2. **ProcessGroupGang** + + - Wraps PyTorch's ProcessGroup for actual distributed training + - Supports both NCCL (for GPU) and Gloo (for CPU) backends + - Handles monitored barriers and collective operations (e.g., `all_reduce`, `all_gather`, `broadcast`) + +Distributed Training Basics +--------------------------- + +Key Terms +^^^^^^^^^ + +1. **World Size**: The total number of processes participating in distributed training. +2. **Rank**: The unique ID of a process within the world. +3. **Device**: The hardware (CPU/GPU) associated with each process +4. **Process Group**: A subset of processes for performing collective operations. + +Collective Operations +^^^^^^^^^^^^^^^^^^^^^ + +The Gang interface supports the following methods: + +.. code-block:: python + + # Reduce tensor across processes + gang.all_reduce(tensor, ReduceOperation.SUM) + + # Gather tensors from all processes + gang.all_gather(output_tensor, input_tensor) + + # Gather tensors from all processes into a list + gang.all_gather_to_list(output_tensors, input_tensor) + + # Broadcast tensor from source rank to all others + gang.broadcast(tensor, source_rank=0) + + # Synchronize all processes + gang.barrier() + + # Broadcast Python objects + gang.broadcast_objects(objects, source_rank=0) + +Parallel Training Architecture +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In fairseq2, parallel training is organized around Data Parallel (DP) Gangs and Tensor Parallel (TP) Gangs, which together enable scalable training of large models. +For example, the ``setup_parallel_gangs(root_gang, tp_size=2)`` function creates a root gang (e.g., 8 processes) and then creates 2 DP gangs and 4 TP gangs. + +.. image:: /_static/img/gang.svg + :width: 600px + :align: center + :alt: Gang Architecture + +Structure and Organization of DP and TP Gangs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +1. Data Parallel (DP) Gangs: + + - Group GPUs that process different data batches (or parts of batches). + - Synchronize gradients across the GPUs in the same DP Gang after the backward pass. + - Example: DP Gang 1 has GPUs 0, 2, 4, and 6, while DP Gang 2 has GPUs 1, 3, 5, and 7. + +2. Tensor Parallel (TP) Gangs: + + - Group GPUs that split the model parameters for parallel computation. + - Operate within the same DP Gang but compute sequentially during forward and backward passes. + - Example: TP Gang 1 has GPUs 0 and 1, while TP Gang 2 has GPUs 2 and 3. + +**A Single Training Step** + +1. Forward Pass: + + - Input data is distributed among Data Parallel (DP) Gangs. + - Each Tensor Parallel (TP) Gang processes its segment of the model sequentially, transferring activations between GPUs. + +2. Backward Pass: + + - Gradients are calculated in the reverse sequence of the forward pass within TP Gangs. + - Activation gradients are relayed back to preceding GPUs. + +3. Gradient Synchronization: + + - Gradients are synchronized across all GPUs within each DP Gang. + +4. Parameter Update: + + - Each GPU updates its local parameters (or shards, if utilizing TP). + +.. dropdown:: How step-by-step parallel training works + :icon: code + :animate: fade-in + + - Step 1: Data Splitting + - The global input batch is divided into sub-batches, each assigned to a specific DP Gang + + - Step 2: Forward Pass (TP Gangs) + - Each TP Gang processes its shard of the model sequentially: + - GPU 0 (TP Gang 1) computes layers 0-2, passing activations to GPU 1. + - GPU 1 (TP Gang 1) computes layers 3-5 using these activations. + - This process is repeated for all TP Gangs. + + - Step 3: Backward Pass (TP Gangs) + - The reverse order of the forward pass: + - Gradients of layers 2-3 are computed on GPU 1. + - Activation gradients are sent back to GPU 0, which computes gradients for layers 0-1. + + - Step 4: Gradient Synchronization (DP Gangs) + - Gradients are synchronized across GPUs within the same DP Gang using an ``all_reduce`` operation. + + - Step 5: Parameter Update + - Each GPU updates its parameters or model shards locally after synchronization. + +.. dropdown:: The list of environment variables picked up by fairseq2 + :icon: code + :animate: fade-in + + The following environment variables control distributed training: + + - ``WORLD_SIZE``: Total number of processes. + - ``RANK``: Rank of the current process. + - ``LOCAL_WORLD_SIZE``: Number of processes per node. + - ``LOCAL_RANK``: Local rank within a node. + - ``MASTER_ADDR``: Address of rank 0 process + - ``MASTER_PORT``: Port for rank 0 process + + ``torchrun`` and SLURM automatically sets these variables. + + +Usage Examples +-------------- + +1. Basic Gang Setup +^^^^^^^^^^^^^^^^^^^ + +For standard distributed training: + +.. code-block:: python + + from fairseq2.gang import setup_default_gang + + # Initialize the default gang + gang = setup_default_gang() + + print(f"Process rank: {gang.rank}, World size: {gang.size}") + + +.. note:: + + If running locally (no ``torch.distributed`` backend), a ``FakeGang`` is created. + This is useful for local testing and debugging. + + If running in a distributed environment, a ``ProcessGroupGang`` is created. + +2. Create a Sub-Gang +^^^^^^^^^^^^^^^^^^^^^ + +You can create sub-groups of processes (`e.g.`, for model parallelism): + +.. code-block:: python + + sub_gang = gang.make_gang([0, 1, 2]) + if sub_gang: + print(f"Sub-gang rank: {sub_gang.rank}, Size: {sub_gang.size}") + +3. Data & Tensor Parallelism +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from fairseq2.gang import setup_parallel_gangs + + # Setup root gang first + root_gang = setup_default_gang() + + # Create DP and TP gangs with tensor parallel size = 2 + gangs = setup_parallel_gangs(root_gang, tp_size=2) + + print(f"Data Parallel Rank: {gangs.dp.rank}") + print(f"Tensor Parallel Rank: {gangs.tp.rank}") + + +4. Collective Operations +^^^^^^^^^^^^^^^^^^^^^^^^ + +A minimal example of distributed training with gangs: + +.. code-block:: python + + # script.py + import torch + from fairseq2.gang import setup_default_gang, ReduceOperation + + # Initialize gang + gang = setup_default_gang() + + # Dummy tensor + tensor = torch.tensor(gang.rank + 1.0, device=gang.device) + + # Sum tensor across all processes + gang.all_reduce(tensor, ReduceOperation.SUM) + print(f"Rank {gang.rank}: Tensor after all_reduce = {tensor.item()}") + + # Synchronize + gang.barrier() + + +To run this example w/ torchrun: + +.. code-block:: bash + + torchrun --nproc_per_node=4 script.py + + +Best Practices +-------------- + +1. **Development Workflow** + + - Start with ``FakeGang`` for local development + - Move to distributed training once code works locally + - Use monitored barriers to detect deadlocks + +2. **Process Layout** + + - Place adjacent ranks on same node for TP efficiency + - Balance DP and TP sizes based on model and data characteristics + +3. **Launching Jobs** + + - Use ``torchrun`` for simple distributed training: + + .. code-block:: bash + + torchrun --nproc_per_node=4 train.py + + - Use SLURM for cluster environments: + + .. code-block:: bash + + srun -N 1 --gres=gpu:4 --cpus-per-task=12 python train.py + + +4. **Error Handling** + + - Always synchronize processes with barriers at critical points + - Monitor for process failures in production settings + - Enable logging for debugging distributed issues + +5. **Device Placement** + + - Ensure tensors are on correct devices before collective ops + - Use ``gang.device`` to get the appropriate device + +6. **Resource Management** + + - Close gangs properly when done + +See Also +-------- + +- :ref:`basics-trainer` - How Gang integrates with training diff --git a/doc/source/basics/overview.rst b/doc/source/basics/overview.rst new file mode 100644 index 000000000..e1d235c2c --- /dev/null +++ b/doc/source/basics/overview.rst @@ -0,0 +1,107 @@ +.. _basics-overview: + +:octicon:`telescope` Overview +============================= + +fairseq2 is a sequence modeling framework designed for research and deployment of advanced machine learning models. +It builds on the strengths of its predecessor, fairseq, with modernized architecture, enhanced extensibility, and streamlined workflows for training and fine-tuning. + +Key Features +------------ + +.. dropdown:: Command Line Interface (CLI) + :icon: terminal + :animate: fade-in + + * Unified commands for training, fine-tuning, and evaluating models. + + * Modular options to customize tasks and configurations. + + * Learn more: :ref:`basics-cli` + +.. dropdown:: Assets + :icon: container + :animate: fade-in + + * Pre-trained models, datasets, and configurations bundled for easy reuse. + + * Simplified asset management for seamless integration. + + * Learn more: :ref:`basics-assets` + +.. dropdown:: Model Loader + :icon: package + :animate: fade-in + + * A robust mechanism for loading and saving models. + + * Supports checkpointing, versioning, and compatibility. + +.. dropdown:: Data Pipeline + :icon: database + :animate: fade-in + + * Extensible pipeline for preprocessing, augmentation, and batching. + + * Optimized for large-scale datasets and varied formats. + + * Learn more: :ref:`basics-data-pipeline` + + +.. dropdown:: Gang + :icon: server + :animate: fade-in + + * A novel abstraction for distributed tasks and parallel processing. + + * Simplifies workload distribution and resource management. + +Design Principles +----------------- + + +.. dropdown:: Dependency Inversion + :icon: infinity + :animate: fade-in + + * Decouples components to promote flexibility and testing. + + * Encourages the development of reusable modules. + + * Learn more: :ref:`basics-design-philosophy` + +.. dropdown:: Simplicity + :icon: check + :animate: fade-in + + * Emphasizes clear APIs and intuitive workflows. + + * Minimizes boilerplate code and unnecessary complexity. + +.. dropdown:: Modularity + :icon: plug + :animate: fade-in + + * Designed with modularity in mind. + + * Easy to add new datasets, models and trainers. + + * Learn more: :ref:`basics-runtime-extensions` + +.. dropdown:: Performance + :icon: flame + :animate: fade-in + + * Optimized for scalability and efficiency. + + * Supports state-of-the-art techniques for distributed training. + +.. dropdown:: Community-Centric + :icon: heart + :animate: fade-in + + * Active collaboration and contributions from the research community. + + * Comprehensive documentation and resources for onboarding. + + * Learn more: :ref:`faq-contributing` diff --git a/doc/source/basics/recipe.rst b/doc/source/basics/recipe.rst new file mode 100644 index 000000000..e1d959433 --- /dev/null +++ b/doc/source/basics/recipe.rst @@ -0,0 +1,78 @@ +.. _basics-recipe: + +:octicon:`gift` Recipes +======================= + +Instruction Fine-tuning Example +------------------------------- + +The recipe handler for instruction fine-tuning is registered at :meth:`fairseq2.recipes.lm._setup_cli`. +The recipe itself resides at :meth:`fairseq2.recipes.lm.load_instruction_finetuner`. +This recipe loads a set of components into a :class:`fairseq2.recipes.trainer.Trainer`, including: + +- Model +- Data reader (train and valid) +- Checkpoint manager +- Criterion +- Metric recorder + +The only inputs required are the configuration and an output directory for checkpoints, training events, and more. + +How to Configure a Recipe +------------------------- + +Recipe configuration is defined as a class, such as :class:`fairseq2.recipes.lm.instruction_finetune.InstructionFinetuneConfig` for instruction fine-tuning. + +Customized configurations can be loaded from YAML files. For example, if you create a YAML configuration file at ``$YOUR_CONFIG.yaml``: + +.. code-block:: yaml + + dataset: /data/gsm8k_data/sft + model: llama3_1_8b + max_num_tokens: 4096 + max_seq_len: 4096 + max_num_steps: 1000 + max_num_data_epochs: 20 + checkpoint_every_n_steps: 1000 + keep_last_n_checkpoints: 1 + keep_last_n_models: 1 + publish_metrics_every_n_steps: 5 + +This configuration can be loaded using the following command: + +.. code-block:: bash + + fairseq2 lm instruction_finetune --config-file $YOUR_CONFIG.yaml ... + +Key Points on Configuration +--------------------------- + +* **Using YAML Files:** + + * The YAML file path is provided as the argument value to ``--config-file``. + + * The YAML file does not need to represent the entire configuration. + + * You can dump the default preset configuration to view default values using the ``--dump-config`` argument with the recipe command. + + * Multiple ``--config-file`` arguments can be used, and configurations will be merged from left to right, with the last value overriding previous ones. + +* **Overriding via CLI:** + + * Configuration values can be adjusted directly via the CLI using ``--config k=v``. + + * For example: ``--config optimizer_config.lr=4e-5`` + + * Multiple key-value pairs can be passed: + + * Example: ``--config max_num_tokens=512 max_seq_len=512`` + + * CLI overrides can be combined with ``--config-file``, where CLI values will take precedence over YAML values. + +* **Using Add/Del Directives:** + + * Directives such as ``add`` or ``del`` can be used for more advanced overrides: + + * Add a new value: ``--config add:xyz=value`` + + * Remove a value: ``--config del:xyz`` diff --git a/doc/source/basics/runtime_extensions.rst b/doc/source/basics/runtime_extensions.rst new file mode 100644 index 000000000..a34b99af0 --- /dev/null +++ b/doc/source/basics/runtime_extensions.rst @@ -0,0 +1,141 @@ +.. _basics-runtime-extensions: + +:octicon:`plug` Runtime Extensions +================================== + +fairseq2 provides a flexible runtime extension system that allows you to extend its functionality without modifying the core codebase. This system leverages Python's setuptools entry points to dynamically load and register extensions during initialization. + +Overview +-------- + +The extension system is built around a dependency injection container (learn more in :ref:`basics-design-philosophy`) that manages fairseq2's components. +Through this system, you can: + +* Register new models +* Add custom asset providers +* Extend the runtime context +* Register custom tensor loaders/dumpers +* Add value converters +* And more... + +Basic Usage +----------- + +Before using any fairseq2 APIs, you must initialize the framework with :meth:`fairseq2.setup_fairseq2`: + +.. code-block:: python + + from fairseq2 import setup_fairseq2 + + setup_fairseq2() + +Creating Extensions +------------------- + +To create an extension, define a setup function: + +.. code-block:: python + + def setup_my_extension() -> None: + # Register your custom components here + pass + +Registering Extensions +---------------------- + +Extensions are registered using setuptools entry points. You can configure them in either ``setup.py`` or ``pyproject.toml``: + +Using setup.py: + +.. code-block:: python + + setup( + name="my-fairseq2-extension", + entry_points={ + "fairseq2": [ + "my_extension = my_package.module:setup_my_extension", + ], + }, + ) + +Using pyproject.toml: + +.. code-block:: toml + + [project.entry-points."fairseq2"] + my_extension = "my_package.module:setup_my_extension" + +Extension Loading Process +------------------------- + +When ``setup_fairseq2()`` is called, the following steps occur: + +1. fairseq2 components are initialized +2. All registered extensions are discovered via entry points +3. Each extension's setup function is called + +Complete Example +---------------- + +Here's a complete example of implementing a fairseq2 extension: + +.. code-block:: python + + from fairseq2.assets import default_asset_store + + def setup_my_extension() -> None: + default_asset_store.add_package_metadata_provider("my_package") + +Error Handling +-------------- + +The extension system includes error handling to maintain system stability: + +* Failed extensions log warnings by default +* Set ``FAIRSEQ2_EXTENSION_TRACE`` environment variable for detailed error traces +* Invalid extension functions raise ``RuntimeError`` + +.. code-block:: bash + + export FAIRSEQ2_EXTENSION_TRACE=1 + + +Best Practices +-------------- + +We suggest the following best practices for implementing extensions. + +Documentation +^^^^^^^^^^^^^ + +* Document your extension's functionality +* Specify requirements and dependencies +* Include usage examples + +Testing +^^^^^^^ + +* Test extensions in isolation +* Verify integration with fairseq2 +* Test error cases and edge conditions + +Error Handler +^^^^^^^^^^^^^ + +* Implement proper error handling +* Fail fast if required dependencies are missing +* Provide meaningful error messages + +Configuration +------------- + +Environment Variables +^^^^^^^^^^^^^^^^^^^^^ + +``FAIRSEQ2_EXTENSION_TRACE`` + Set this environment variable to enable detailed stack traces when extensions fail to load. + +See Also +-------- + +* :doc:`/reference/api/fairseq2.assets/index` diff --git a/doc/source/basics/trainer.rst b/doc/source/basics/trainer.rst new file mode 100644 index 000000000..368f8add3 --- /dev/null +++ b/doc/source/basics/trainer.rst @@ -0,0 +1,354 @@ +.. _basics-trainer: + +:octicon:`dependabot` Trainer +============================= + +The :class:`fairseq2.recipes.trainer.Trainer` class is the core class for training models. + +Overview +-------- + +The trainer in fairseq2 is designed to be flexible and model-agnostic, handling various training scenarios from simple models to complex distributed training setups. +It is probably the most complex system in fairseq2, but also the most powerful. + +.. mermaid:: + + flowchart LR + %% Main Trainer Class + A[Trainer] --> B[TrainUnit] + A --> C[DataReader] + A --> D[Optimizer] + A --> E[CheckpointManager] + A --> H[LRScheduler] + A --> I[Gang System] + A --> P[Metrics Logging] + A --> V[Validation] + + %% TrainUnit Components + B --> F[Model] + B --> G[MetricBag] + + %% Gang System + I --> J[Root Gang] + I --> K[DP Gang] + I --> L[TP Gang] + + %% Metrics Logging + P --> P1[TensorBoard] + P --> P2[WandB] + P --> P3[JSON Logger] + + %% Validation + V --> Q[EvalUnit] + V --> R[Validation DataReader] + + %% CheckpointManager Details + E --> E1[Save State] + E --> E2[Load State] + E --> E3[Keep Best Checkpoints] + E --> E4[Save FSDP Model] + + +Core Components +--------------- + +TrainUnit +^^^^^^^^^ + +The ``TrainUnit`` is an abstract class that encapsulates model-specific training logic: + +.. code-block:: python + + class TrainUnit(ABC, Generic[BatchT_contra]): + """Represents a unit to be used with Trainer.""" + + @abstractmethod + def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]: + """Process batch and return loss and number of targets.""" + + @abstractmethod + def set_step_nr(self, step_nr: int) -> None: + """Set current training step number.""" + + @property + @abstractmethod + def model(self) -> Module: + """The underlying model.""" + + @property + @abstractmethod + def metric_bag(self) -> MetricBag: + """Training-related metrics.""" + +.. dropdown:: Example implementation + :icon: code + :animate: fade-in + + .. code-block:: python + + class TransformerTrainUnit(TrainUnit[TransformerBatch]): + def __init__(self, model: TransformerModel) -> None: + super().__init__(model) + self._metric_bag = MetricBag() + self._metric_bag.register_metric("loss", Mean()) + + def __call__(self, batch: TransformerBatch) -> tuple[Tensor, int]: + outputs = self._model(**batch) + return outputs.loss, batch.num_tokens + +Trainer Configuration +^^^^^^^^^^^^^^^^^^^^^ + +The :class:`fairseq2.recipes.trainer.Trainer` class accepts a wide range of configuration options: + +.. code-block:: python + + # Example Trainer Configuration + trainer = Trainer( + # Basic parameters + unit=train_unit, # Training unit to compute loss + data_reader=data_reader, # Data reader for training batches + optimizer=optimizer, # Optimizer + checkpoint_manager=checkpoint_mgr, # Checkpoint manager + root_gang=root_gang, # Root gang for distributed training + + # Optional parameters + dp_gang=dp_gang, # Data parallel gang + tp_gang=tp_gang, # Tensor parallel gang + dtype=torch.float32, # Model data type + lr_scheduler=lr_scheduler, # Learning rate scheduler + max_num_steps=100_000, # Maximum training steps + max_num_data_epochs=10, # Maximum training epochs + + # Validation parameters + valid_units=[valid_unit], # Validation units + valid_data_readers=[valid_reader], # Validation data readers + validate_every_n_steps=1_000, # Validation frequency + + # Checkpoint parameters + checkpoint_every_n_steps=5_000, # Checkpoint frequency + keep_last_n_checkpoints=5, # Number of checkpoints to keep + + # Metric parameters + publish_metrics_every_n_steps=100, # Metric publishing frequency + tb_dir=Path("runs"), # TensorBoard directory + metrics_dir=Path("metrics"), # Metrics directory + ) + +Training Flow +------------- + +The training process follows this simplified sequence: + +.. mermaid:: + + sequenceDiagram + participant T as Trainer + participant U as TrainUnit + participant D as DataReader + participant M as Model + participant O as Optimizer + + T->>D: Request batch + D-->>T: Return batch + T->>U: Process batch + U->>M: Forward pass + M-->>U: Return loss + U-->>T: Return loss, num_targets + T->>M: Backward pass + T->>O: Update parameters + T->>T: Update metrics + +.. dropdown:: Step-by-step breakdown + :icon: code + :animate: fade-in + + We provide a simplified step-by-step process for the trainer in the following code snippet to help you understand the training flow. + + 1. **Initialization**: The trainer is initialized with the necessary components and configurations. + + .. code-block:: python + + def __init__(self, unit: TrainUnit[BatchT], data_reader: DataReader[BatchT], ...): + self._model = unit.model + self._unit = unit + self._data_reader = data_reader + # ... initialize other components + + 2. **Training Loop**: The training loop is implemented in the ``_do_run`` method: + + .. code-block:: python + + def _do_run(self) -> None: + while self._should_run_step(): + self._step_nr += 1 + + # Run training step + self._run_step() + + # Maybe validate + if self._should_validate(): + self._validate() + + # Maybe checkpoint + if self._should_checkpoint(): + self._checkpoint() + + 3. **Step Execution**: The ``_run_step`` method is responsible for executing a single training step: + + .. code-block:: python + + def _run_step(self) -> None: + # Collect batches + batches = self._next_batches() + + # Process each batch + for batch in batches: + # Forward pass + loss, num_targets = self._unit(batch) + + # Backward pass + self._loss_scaler.backward(loss) + + # Update parameters + self._loss_scaler.run_optimizer_step(self._step_nr) + + 4. **Validation**: The validation loop is implemented in the ``_validate`` method: + + .. code-block:: python + + def _validate(self) -> None: + log.info("Starting validation after step {}.", self._step_nr) + + self._model.eval() + + with summon_fsdp_for_validation(self._model): + unit_scores = [] + + for unit, data_reader in zip(self._valid_units, self._valid_data_readers): + unit_score = self._validate_unit(unit, data_reader) + if unit_score is not None: + unit_scores.append(unit_score) + + self._valid_score = self._compute_valid_score(unit_scores) + + self._model.train() + + log.info("Validation complete.") + + - Validation occurs at specified intervals (steps or epochs). + - `e.g.`: ``validate_every_n_steps`` or ``validate_every_n_data_epochs`` + - It computes a score (like accuracy) using :class:`fairseq2.recipes.evaluator.EvalUnit` objects and logs metrics. + - The validation score is compared to previous scores to: + - Save the best checkpoints. + - Stop early if performance stagnates. + + 5. **Checkpoint**: The checkpointing logic is implemented in the ``_checkpoint`` method: + + .. code-block:: python + + def _checkpoint(self) -> None: + # Save checkpoint + step_nr = self._step_nr + + self._checkpoint_manager.begin_checkpoint(step_nr) + + self._checkpoint_manager.save_state( + self.state_dict(), model_key="_model" + ) + + - The trainer saves checkpoints periodically at specified intervals (steps or epochs): + - `e.g.`: ``checkpoint_every_n_steps`` or ``checkpoint_every_n_data_epochs`` + - Both model weights and optimizer state are saved. + - Best-performing models are saved based on the validation score. + - `e.g.`: ``keep_best_n_checkpoints=3`` + - The checkpoint manager handles the checkpoint saving and loading, which ensures: + - Training can be resumed after interruptions. + - Best models are preserved for deployment. + + 6. **Metrics Logging**: The metrics logging logic is implemented in the ``_publish_metrics`` method: + + .. code-block:: python + + def _publish_metrics(self) -> None: + if self._tp_gang.rank == 0: + values = self._metric_bag.sync_and_compute_metrics() + record_metrics(self._metric_recorders, "train", values, self._step_nr) + + - The trainer supports multiple logging backends: + - TensorBoard: Visualize training curves + - ``tb_dir = Path("logs/tb")`` + - JSON Logs: Store metrics in files + - ``metrics_dir = Path("logs/metrics")`` + - Weights & Biases (WandB): Collaborative logging + - ``wandb_options = (Path("logs/wandb"), "project_name", "run_name")`` + + +Best Practices +-------------- + +#. **Metric Tracking**: + + - Register all relevant metrics in the train unit + - Use appropriate metric types (Mean, Sum, etc.) + - Consider adding validation metrics + +#. **Resource Management**: + + - Use appropriate batch sizes for your hardware + - Enable ``amp`` for memory efficiency + - Configure gradient accumulation as needed + +#. **Checkpoint Management**: + + - Save checkpoints regularly + - Implement proper cleanup strategy + +#. **Validation**: + + - Validate at appropriate intervals + - Track relevant validation metrics + - Implement early stopping if needed + + +Advanced Features +----------------- + +#. **Early Stopping**: + + .. code-block:: python + + def early_stopper(step_nr: int, score: float) -> bool: + # Custom early stopping logic + return score < threshold + + trainer = Trainer( + early_stopper=early_stopper, + score_metric_name="validation_loss", + lower_better=True, + ) + + +#. **Custom Learning Rate Scheduling**: + + .. code-block:: python + + class CustomLRScheduler(LRScheduler): + def get_lr(self) -> float: + # Custom LR calculation + return self.base_lr * decay_factor(self.step_nr) + + trainer = Trainer( + lr_scheduler=CustomLRScheduler(optimizer), + ) + + +#. **Profiling**: + + .. code-block:: python + + trainer = Trainer( + profile=(100, 10), # Skip 100 steps, profile 10 steps + tb_dir=Path("logs/tb"), # Save profiles to TensorBoard + ) + diff --git a/doc/source/conf.py b/doc/source/conf.py new file mode 100644 index 000000000..063748144 --- /dev/null +++ b/doc/source/conf.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import fairseq2n + +fairseq2n.DOC_MODE = True + +import fairseq2 + +# ------------------------------------------------------------ +# Project Information +# ------------------------------------------------------------ + +project = "fairseq2" +version = fairseq2.__version__ +release = fairseq2.__version__ +author = "Fundamental AI Research (FAIR) at Meta" + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + + +# ------------------------------------------------------------ +# General Configuration +# ------------------------------------------------------------ + +needs_sphinx = "7.4.0" + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosectionlabel", + "sphinx.ext.autosummary", + "sphinx.ext.coverage", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinxcontrib.bibtex", + "sphinx_favicon", + "sphinx_design", + "sphinxcontrib.mermaid", + "myst_parser", +] + +myst_enable_extensions = ["colon_fence"] + +primary_domain = "py" + +highlight_language = "python3" + +autoclass_content = "both" +autodoc_class_signature = "mixed" +autodoc_default_options = { + "members": True, + "show-inheritance": True, +} +autodoc_member_order = "bysource" +autodoc_typehints = "description" +autodoc_typehints_description_target = "documented_params" +autodoc_typehints_format = "short" + +autosectionlabel_prefix_document = True + +todo_include_todos = True + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), + "torch": ("https://pytorch.org/docs/stable/", None), +} + +templates_path = ["templates"] + +bibtex_bibfiles = ["_static/bibliography.bib"] + +# ------------------------------------------------------------ +# HTML Output Options +# ------------------------------------------------------------ + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_title = project +html_theme = "furo" +html_logo = "_static/img/logo.png" + +html_theme_options = { + "light_css_variables": { + "color-brand-primary": "#008080", + "color-brand-content": "#008080", + }, + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/facebookresearch/fairseq2", + "html": """ + + + + """, + "class": "", + }, + ], +} +html_show_copyright = False +html_static_path = ["_static"] +html_title = "fairseq2 Documentation" + +favicons = [ + {"href": "img/logo.png"}, # => use `_static/img/logo.png` +] diff --git a/doc/source/faq/contributing.rst b/doc/source/faq/contributing.rst new file mode 100644 index 000000000..c5dea4191 --- /dev/null +++ b/doc/source/faq/contributing.rst @@ -0,0 +1,224 @@ +.. _faq-contributing: + + +:octicon:`heart` Contributing to fairseq2 +========================================= + +We want to make contributing to fairseq2 as easy as possible. Please make sure +to read this guideline carefully. + + +.. _faq-contributing-setup: + +Setting up Development Environment +---------------------------------- + +fairseq2 consists of two packages; the user-facing fairseq2 package implemented +in pure Python, and the fairseq2n package that contains the C++ and CUDA +portions of the library. If pre-built fairseq2n nightly packages are available +for your system (check [README](.#nightlies)), and if you are interested in only +modifying Python portions of fairseq2, you can use an editable pip installation +as described below. Otherwise, if you are planning to work on C++ or CUDA, or if +fairseq2n is not available as a pre-built package for your system, please follow +the installation instructions [here](INSTALL_FROM_SOURCE.md). + +For an editable installation, first, install a nightly build of fairseq2n (shown +for PyTorch ``2.4.0`` and variant ``cu121``): + +.. code-block:: sh + + pip install fairseq2n\ + --pre --upgrade --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.4.0/cu121 + + +.. warning:: + fairseq2n relies on the C++ API of PyTorch which has no API/ABI compatibility + between releases. This means **you have to install the fairseq2n variant that + exactly matches your PyTorch version**. Otherwise, you might experience issues + like immediate process crashes or spurious segfaults. For the same reason, if + you upgrade your PyTorch version, you must also upgrade your fairseq2n + installation. + +Then, clone the fairseq2 repository to your machine: + +.. code-block:: sh + + git clone https://github.com/facebookresearch/fairseq2.git + + cd fairseq2 + +And, install the fairseq2 package in editable mode: + +.. code-block:: sh + + pip install -e . + +Finally, make sure to install the development tools (e.g. linters and +formatters): + +.. code-block:: sh + + pip install -r requirements-devel.txt + +.. note:: + Any time you pull the latest fairseq2 commits from GitHub, make sure to re-run + the fairseq2n installation command above to get the most up-to-date binary. If + you observe runtime or test failures after the installation, it might be + because the latest nightlies are not published yet. If the problem persists + for more than 12 hours, please create a + [GitHub issue](https://github.com/facebookresearch/fairseq2/issues/new/choose). + + +.. _faq-contributing-testing: + +Testing Your Work +----------------- + +Any work that you plan to contribute should ideally be covered by a unit or +integration test. Once you have all your tests in place, ensure the full test +suite passes: + +.. code-block:: sh + + pytest + +By default, the tests will be run on CPU; pass the ``--device`` (short form +``-d``) option to run them on a specific device (e.g. GPU): + +.. code-block:: sh + + pytest --device cuda:0 + + +If you have changes in C++ or CUDA, in addition to ``pytest``, also run the +native tests: + +.. code-block:: sh + + native/build/tests/run-tests + + + +.. _faq-contributing-documenting: + +Documenting Your Work +--------------------- + +Any new or revised user-facing feature included in your work should have an +accompanying documentation. Depending on the scope of the work, the +documentation can be just docstrings in Python code, or, for larger features, +one or more Sphinx RST files. For docstrings, make sure to follow our formatting +conventions. You can check out any Python file in our code base to study how we +format our docstrings. + +To build and test out the library documentation, run the following commands: + +.. code-block:: sh + + cd doc + + pip install -r requirements.txt + + make html + + cd build/html + + python -m http.server 8084 + +and, visit `http://localhost:8084 `__ in your browser. + + +.. _faq-contributing-linting: + +Linting Your Work +----------------- + +If you have made changes to the Python code, run the following command and +address any issues reported: + +.. code-block:: sh + + mypy && flake8 . + + + +If you have touched C++ or CUDA files, lint your code with an up-to-date version +of the clang toolkit and address any issues reported: + +.. code-block:: sh + + cd native + + CC=clang CXX=clang++ cmake -GNinja -DFAIRSEQ2N_RUN_CLANG_TIDY=ON -B build + + cmake --build build + + +Alternatively: + +.. code-block:: sh + + cd native + + CC=clang CXX=clang++ cmake -GNinja -B build + + run-clang-tidy -p build + + +.. _faq-contributing-formatting: + +Formatting Your Work +-------------------- + +For Python code, run the following command: + +.. code-block:: sh + + isort . && black . + + +For C++ and CUDA, we do not enforce our coding conventions via a tool (e.g. +clang-format), but we expect you to follow them. You can check out any C++ file +in our code base to study our conventions. Since C++ syntax can become pretty +complex at times, refrain from being too pedantic and prioritize readability +over convention. + + +.. _faq_contributing_checklist: + +Check List for Pull Requests +---------------------------- + +1. Fork the repository and create your branch from ``main``. +2. If you've added code that should be tested, add tests, and ensure the entire + test suite passes. +3. If you've added or revised a user-facing feature, update the documentation. +4. Lint and format your code. +5. If you haven't already, complete the Contributor License Agreement ("CLA"). + + +.. _faq_contributing_cla: + +Contributor License Agreement +----------------------------- + +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open-source projects. + +Complete your CLA here: `https://code.facebook.com/cla `__ + + +.. _faq_contributing_issues: + +Issues +^^^^^^ + +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + + +License +^^^^^^^ + +By contributing to fairseq2, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/doc/source/getting_started/installation/index.rst b/doc/source/getting_started/installation/index.rst new file mode 100644 index 000000000..dd6cac636 --- /dev/null +++ b/doc/source/getting_started/installation/index.rst @@ -0,0 +1,238 @@ +.. _installation: + +================================ +:octicon:`download` Installation +================================ + +.. _installation_linux: + +Installing on Linux +------------------- + +Linux OS System Dependencies +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, fairseq2 is installed on Linux OS. fairseq2 depends on +`libsndfile `_, which can be installed +via the system package manager on most Linux distributions. For Ubuntu-based +systems, run: + +.. code-block:: sh + + sudo apt install libsndfile1 + +Similarly, on Fedora, run: + +.. code-block:: sh + + sudo dnf install libsndfile + +For other Linux distributions, please consult its documentation on how to +install packages. + +pip (Linux) +^^^^^^^^^^^ + +To install fairseq2 on Linux x86-64, run: + +.. code-block:: sh + + pip install fairseq2 + +This command will install a version of fairseq2 that is compatible with PyTorch +hosted on PyPI. + +At this time, we do not offer a pre-built package for ARM-based systems such as +Raspberry PI or NVIDIA Jetson. Please refer to :ref:`installation_from_source` to +learn how to build and install fairseq2 on those systems. + +Variants (Linux) +^^^^^^^^^^^^^^^^ + +Besides PyPI, fairseq2 also has pre-built packages available for different +PyTorch and CUDA versions hosted on FAIR's package repository. The following +matrix shows the supported combinations. + +.. list-table:: Supported Combinations + :header-rows: 1 + :widths: 15 15 20 20 10 + + * - fairseq2 + - PyTorch + - Python + - Variant* + - Arch + * - ``HEAD`` + - ``2.4.0`` + - ``>=3.10``, ``<=3.12`` + - ``cpu``, ``cu118``, ``cu121`` + - ``x86_64`` + * - ``HEAD`` + - ``2.3.0``, ``2.3.1`` + - ``>=3.10``, ``<=3.12`` + - ``cpu``, ``cu118``, ``cu121`` + - ``x86_64`` + * - ``HEAD`` + - ``2.2.0``, ``2.2.1``, ``2.2.2`` + - ``>=3.10``, ``<=3.12`` + - ``cpu``, ``cu118``, ``cu121`` + - ``x86_64`` + * - ``0.2.0`` + - ``2.1.1`` + - ``>=3.8``, ``<=3.11`` + - ``cpu``, ``cu118``, ``cu121`` + - ``x86_64`` + * - ``0.2.0`` + - ``2.0.1`` + - ``>=3.8``, ``<=3.11`` + - ``cpu``, ``cu117``, ``cu118`` + - ``x86_64`` + * - ``0.2.0`` + - ``1.13.1`` + - ``>=3.8``, ``<=3.10`` + - ``cpu``, ``cu116`` + - ``x86_64`` + +*\* cuXYZ refers to CUDA XY.Z (e.g. cu118 means CUDA 11.8)* + + +To install a specific combination, first follow the installation instructions on +`pytorch.org `_ for the desired PyTorch +version, and then use the following command (shown for PyTorch `2.4.0` and +variant `cu121`): + +.. code-block:: bash + + pip install fairseq2\ + --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.4.0/cu121 + +.. warning:: + + fairseq2 relies on the C++ API of PyTorch which has no API/ABI compatibility + between releases. This means **you have to install the fairseq2 variant that + exactly matches your PyTorch version**. Otherwise, you might experience issues + like immediate process crashes or spurious segfaults. For the same reason, if + you upgrade your PyTorch version, you must also upgrade your fairseq2 + installation. + +Nightlies +^^^^^^^^^ + +For Linux, we also host nightly builds on FAIR's package repository. The +supported variants are identical to the ones listed in *Variants* above. Once +you have installed the desired PyTorch version, you can use the following +command to install the corresponding nightly package (shown for PyTorch `2.4.0` +and variant `cu121`): + +.. code-block:: sh + + pip install fairseq2\ + --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.4.0/cu121 + +--- + +.. _installation_mac: + +Installing on macOS +------------------- + +macOS System Dependencies +^^^^^^^^^^^^^^^^^^^^^^^^^ + +fairseq2 depends on `libsndfile `__, +which can be installed via Homebrew: + +.. code-block:: sh + + brew install libsndfile + + +pip (macOS) +^^^^^^^^^^^ + +To install fairseq2 on ARM64-based (i.e. Apple silicon) Mac computers, run: + +.. code-block:: sh + + pip install fairseq2 + +This command will install a version of fairseq2 that is compatible with PyTorch +hosted on PyPI. + +At this time, we do not offer a pre-built package for Intel-based Mac computers. +Please refer to :ref:`installation_from_source` to learn how to build and +install fairseq2 on Intel machines. + +Variants (macOS) +^^^^^^^^^^^^^^^^ + +Besides PyPI, fairseq2 also has pre-built packages available for different +PyTorch versions hosted on FAIR's package repository. The following matrix shows +the supported combinations. + +.. list-table:: Supported Combinations + :header-rows: 1 + :widths: 10 10 10 10 + + * - fairseq2 + - PyTorch + - Python + - Arch + * - ``HEAD`` + - ``2.4.0`` + - ``>=3.9``, ``<=3.12`` + - ``arm64`` + +To install a specific combination, first follow the installation instructions on +`pytorch.org `__ for the desired PyTorch +version, and then use the following command (shown for PyTorch `2.4.0`): + +.. code-block:: sh + + pip install fairseq2\ + --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.4.0/cpu + + +.. warning:: + + fairseq2 relies on the C++ API of PyTorch which has no API/ABI compatibility + between releases. This means **you have to install the fairseq2 variant that + exactly matches your PyTorch version**. Otherwise, you might experience + issues like immediate process crashes or spurious segfaults. For the same + reason, if you upgrade your PyTorch version, you must also upgrade your + fairseq2 installation. + +Nightlies (macOS) +^^^^^^^^^^^^^^^^^ + +For macOS, we also host nightly builds on FAIR's package repository. The +supported variants are identical to the ones listed in *Variants* above. Once +you have installed the desired PyTorch version, you can use the following +command to install the corresponding nightly package (shown for PyTorch `2.4.0`): + +.. code-block:: sh + + pip install fairseq2\ + --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.4.0/cpu + +--- + +.. _installation_windows: + +Installing on Windows +--------------------- + +fairseq2 does not have native support for Windows and there are no plans to +support it in the foreseeable future. However, you can use fairseq2 via the +`Windows Subsystem for Linux `__ +(a.k.a. WSL) along with full CUDA support introduced in WSL 2. Please follow the +instructions in the :ref:`installation` section for a WSL-based +installation. + + +.. toctree:: + :maxdepth: 1 + :caption: Other Installation Guides + + installation_from_source + diff --git a/doc/source/getting_started/installation/installation_from_source.rst b/doc/source/getting_started/installation/installation_from_source.rst new file mode 100644 index 000000000..2c61ccfb2 --- /dev/null +++ b/doc/source/getting_started/installation/installation_from_source.rst @@ -0,0 +1,244 @@ +.. _installation_from_source: + +:octicon:`file-binary` Installing from Source +============================================= + +The instructions in this document are for users who want to use fairseq2 on a +system for which no pre-built fairseq2 package is available, or for users who +want to work on the C++/CUDA code of fairseq2. + +.. note:: + If you plan to edit and only modify Python portions of fairseq2, and if + fairseq2 provides a pre-built nightly package for your system, we recommend + using an editable pip installation as described in + :ref:`faq-contributing-setup`. + +1. Clone the Repository +----------------------- + +As first step, clone the fairseq2 Git repository to your machine: + +.. code-block:: sh + + git clone --recurse-submodules https://github.com/facebookresearch/fairseq2.git + +Note the ``--recurse-submodules`` option that asks Git to clone the third-party +dependencies along with fairseq2. If you have already cloned fairseq2 without +``--recurse-submodules`` before reading these instructions, you can run the +following command in your cloned repository to achieve the same effect: + +.. code-block:: sh + + git submodule update --init --recursive + +2. Set up a Python Virtual Environment +-------------------------------------- + +In simplest case, you can run the following command to create an empty Python +virtual environment (shown for Python 3.8): + +.. code-block:: sh + + python3.8 -m venv ~/myvenv + +And, activate it: + +.. code-block:: sh + + source ~/myvenv/bin/activate + +You can check out the +`Python documentation `_ +to learn more about other environment options. + +.. important:: + We strongly recommend creating a new environment from scratch instead of + reusing an existing one to avoid dependency conflicts. + +.. important:: + Manually building fairseq2 or any other C++ project in a Conda environment can + become tricky and fail due to environment-specific conflicts with the host + system libraries. Unless necessary, we recommend using a Python virtual + environment to build fairseq2. + +3. Install Dependencies +----------------------- + +3.1 System Dependencies +^^^^^^^^^^^^^^^^^^^^^^^ + +fairseq2 depends on `libsndfile `__, +which can be installed via the system package manager on most Linux +distributions, or via Homebrew on macOS. + +For Ubuntu-based systems, run: + +.. code-block:: sh + + sudo apt install libsndfile-dev + +Similarly, on Fedora, run: + +.. code-block:: sh + + sudo dnf install libsndfile-devel + +For other Linux distributions, please consult its documentation on how to +install packages. + +For macOS, you can use Homebrew: + +.. code-block:: sh + + brew install libsndfile + +3.2 PyTorch +^^^^^^^^^^^ + +Follow the instructions on `pytorch.org `_ +to install the desired PyTorch version. Make sure that the version you install +is supported by fairseq2. + +3.3 CUDA +^^^^^^^^ + +If you plan to build fairseq2 in a CUDA environment, you first have to install +a version of the CUDA Toolkit that matches the CUDA version of PyTorch. The +instructions for different toolkit versions can be found on NVIDIA’s website. + +.. note:: + If you are on a compute cluster with ``module`` support (e.g. FAIR Cluster), + you can typically activate a specific CUDA Toolkit version by + ``module load cuda/``. + +3.4 pip +^^^^^^^ + +Finally, to install fairseq2’s C++ build dependencies (e.g. cmake, ninja), use: + +.. code-block:: sh + + pip install -r native/python/requirements-build.txt + +4. Build fairseq2n +------------------ + +4.1 CPU-Only Builds +^^^^^^^^^^^^^^^^^^^ + +The final step before installing fairseq2 is to build fairseq2n, fairseq2’s C++ +library. Run the following command at the root directory of your repository to +configure the build: + +.. code-block:: sh + + cd native + + cmake -GNinja -B build + +Once the configuration step is complete, build fairseq2n using: + +.. code-block:: sh + + cmake --build build + +fairseq2 uses reasonable defaults, so the command above is sufficient for a +standard installation; however, if you are familiar with CMake, you can check +out the advanced build options in +`native/CMakeLists.txt `__. + +4.2 CUDA Builds +^^^^^^^^^^^^^^^ +.. note:: + If you are on a compute cluster with ``module`` support (e.g. FAIR Cluster), + you can typically activate a specific CUDA Toolkit version by + ``module load cuda/``. + +If you would like to build fairseq2’s CUDA kernels, set the ``FAIRSEQ2N_USE_CUDA`` +option to ``ON``. When turned on, the version of the CUDA Toolkit installed on +your machine and the version of CUDA that was used to build PyTorch must match: + +.. code-block:: sh + + cmake -GNinja -DFAIRSEQ2N_USE_CUDA=ON -B build + +Similar to CPU-only build, follow this command with: + +.. code-block:: sh + + cmake --build build + +4.3 CUDA Architectures +^^^^^^^^^^^^^^^^^^^^^^ + +By default, fairseq2 builds its CUDA kernels only for the Volta architecture. +You can override this setting using the ``CMAKE_CUDA_ARCHITECTURES`` option. +For +instance, the following configuration generates binary and PTX codes for the +Ampere architecture (e.g. for A100): + +.. code-block:: sh + + cmake -GNinja -DCMAKE_CUDA_ARCHITECTURES="80-real;80-virtual" -DFAIRSEQ2N_USE_CUDA=ON -B build + +5. Install fairseq2 +------------------- + +Once you have built fairseq2n, the actual Python package installation is +straightforward. First install fairseq2n: + +.. code-block:: sh + + cd native/python + + pip install . + + cd - + +Then, fairseq2: + +.. code-block:: sh + + pip install . + +5.1 Editable Install +^^^^^^^^^^^^^^^^^^^^ + +In case you want to modify and test fairseq2, installing it in editable mode +will be more convenient: + +.. code-block:: sh + + cd native/python + + pip install -e . + + cd - + + pip install -e . + +Optionally, you can also install the development tools (e.g. linters, +formatters) if you plan to contribute to fairseq2. See +:ref:`faq-contributing` for more information: + +.. code-block:: sh + + pip install -r requirements-devel.txt + +6. Optional Sanity Check +^^^^^^^^^^^^^^^^^^^^^^^^ + +To make sure that your installation has no issues, you can run the test suite: + +.. code-block:: sh + + pip install -r requirements-devel.txt + + pytest + +By default, the tests will be run on CPU; pass the ``--device`` (short form +``-d``) option to run them on a specific device (e.g. GPU): + +.. code-block:: sh + + pytest --device cuda:0 \ No newline at end of file diff --git a/doc/source/getting_started/quick_start.rst b/doc/source/getting_started/quick_start.rst new file mode 100644 index 000000000..11c2649f6 --- /dev/null +++ b/doc/source/getting_started/quick_start.rst @@ -0,0 +1,53 @@ +.. _quick_start: + +========================================== +:octicon:`rocket` Quick Start +========================================== + +Language Model (LM) +------------------- + +Supervised Fine-Tuning (SFT) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + fairseq2 lm instruction_finetune $OUTPUT_DIR --config \ + dataset=/datasets/facebook/fairseq2-lm-gsm8k/sft \ + model=llama3_2_1b \ + max_num_tokens=4096 \ + dtype=float16 \ + max_num_steps=1000 \ + max_num_data_epochs=20 \ + checkpoint_every_n_steps=1000 + +Read more about this recipe in :ref:`tutorial-end-to-end-fine-tuning`. + + +Generating Text +^^^^^^^^^^^^^^^ + +After fine-tuning a language model, you can generate text with the following command: + +.. code-block:: bash + + CKPT_PATH="/checkpoint/$USER/experiments/$EXPERIMENT_NAME/checkpoints/step_1000" + CKPT_DIR=$(dirname "$CKPT_PATH") + CKPT="checkpoint_$(basename "$CKPT_DIR")" # e.g., checkpoint_step_1000 + SAVE_DIR="$CKPT_DIR/generation" + DATASET="/datasets/facebook/fairseq2-lm-gsm8k/test/test.jsonl" + + fairseq2 lm generate $SAVE_DIR --no-sweep-dir --config \ + checkpoint_dir=$CKPT_DIR \ + model=$CKPT \ + generator_config.temperature=0.1 \ + dataset=$DATASET + + +See Also +-------- + +- :doc:`Design Philosophy ` +- :doc:`Recipe ` +- :doc:`CLI ` +- :doc:`Assets ` diff --git a/doc/source/index.rst b/doc/source/index.rst new file mode 100644 index 000000000..78a303f0a --- /dev/null +++ b/doc/source/index.rst @@ -0,0 +1,85 @@ +:github_url: https://github.com/facebookresearch/fairseq2 + +================================= +Welcome to fairseq2 Documentation +================================= + +fairseq2 is a sequence modeling toolkit that allows researchers and developers +to train custom models for translation, summarization, language modeling, and +other content generation tasks. + +.. grid:: 3 + + .. grid-item-card:: Quick Start + :link: tutorial-end-to-end-fine-tuning + :link-type: ref + + Run a quick start tutorial. + + .. grid-item-card:: Basics + :link: basics-overview + :link-type: ref + + Get familiar with fairseq2. + + .. grid-item-card:: API Reference + :link: reference-api + :link-type: ref + + Jump to the code. + + +Documentation +------------- + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + + getting_started/installation/index + getting_started/quick_start + +.. toctree:: + :maxdepth: 1 + :caption: Basics + + basics/overview + basics/design_philosophy + basics/cli + basics/assets + basics/data_pipeline + basics/ckpt + basics/recipe + basics/runtime_extensions + basics/gang + basics/trainer + +.. toctree:: + :maxdepth: 1 + :caption: Tutorials + + tutorials/end_to_end_fine_tuning + tutorials/monitor_your_experiments + tutorials/presets + tutorials/benchmarking + tutorials/pudb + tutorials/models + +.. toctree:: + :maxdepth: 1 + :caption: FAQ + + faq/contributing + +.. toctree:: + :maxdepth: 1 + :caption: Reference + + reference/api/index + reference/bibliography + +Indices and tables +------------------ + +* :ref:`genindex` +* :ref:`modindex` \ No newline at end of file diff --git a/doc/source/reference/api/fairseq2.assets/index.rst b/doc/source/reference/api/fairseq2.assets/index.rst new file mode 100644 index 000000000..13ed6d932 --- /dev/null +++ b/doc/source/reference/api/fairseq2.assets/index.rst @@ -0,0 +1,17 @@ +fairseq2.assets +=============== + +.. currentmodule:: fairseq2.assets + +.. autoclasstree:: fairseq2.assets + :full: + :zoom: + +Classes +------- + +.. autoclass:: AssetStore + +.. autoclass:: AssetMetadataProvider + +.. autoclass:: AssetCard \ No newline at end of file diff --git a/doc/source/reference/api/fairseq2.data/data_pipeline.rst b/doc/source/reference/api/fairseq2.data/data_pipeline.rst new file mode 100644 index 000000000..11178eb7f --- /dev/null +++ b/doc/source/reference/api/fairseq2.data/data_pipeline.rst @@ -0,0 +1,167 @@ +fairseq2.data.data_pipeline +=========================== + +.. currentmodule:: fairseq2.data.data_pipeline + +The data pipeline module provides the core data processing functionality in fairseq2. + +Classes +------- + +.. autoclass:: DataPipeline + :members: + :special-members: __iter__ + +.. autoclass:: DataPipelineBuilder + :members: + +.. autoclass:: Collater + :members: + :special-members: __call__ + +.. autoclass:: CollateOptionsOverride + :members: + :special-members: __init__ + +.. autoclass:: FileMapper + :members: + :special-members: __init__, __call__ + +.. autoclass:: SequenceData + :members: + +.. autoclass:: FileMapperOutput + :members: + +Functions +--------- + +.. autofunction:: create_bucket_sizes + +.. autofunction:: get_last_failed_example + +.. autofunction:: list_files + +.. autofunction:: read_sequence + +.. autofunction:: read_zipped_records + +.. autofunction:: read_iterator + +Exceptions +---------- + +.. autoclass:: DataPipelineError + :members: + +.. autoclass:: ByteStreamError + :members: + +.. autoclass:: RecordError + :members: + +Examples +-------- + +Creating a Basic Pipeline +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from fairseq2.data import read_sequence, DataPipeline + + # Create a simple pipeline that processes numbers + pipeline = ( + read_sequence([1, 2, 3, 4, 5]) + .map(lambda x: x * 2) + .filter(lambda x: x > 5) + .and_return() + ) + + # Iterate over the results + for item in pipeline: + print(item) # Outputs: 6, 8, 10 + +Using Column Selection +^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Process structured data with column selection + data = [ + {"text": "Hello", "label": 1}, + {"text": "World", "label": 0} + ] + + pipeline = ( + read_sequence(data) + .map(lambda x: x.upper(), selector="text") + .and_return() + ) + + # Results will have uppercase text but unchanged labels + # [{"text": "HELLO", "label": 1}, {"text": "WORLD", "label": 0}] + +Combining Pipelines +^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Create two pipelines + p1 = read_sequence([1, 2, 3]).and_return() + p2 = read_sequence(['a', 'b', 'c']).and_return() + + # Zip them together with names + combined = DataPipeline.zip( + [p1, p2], + names=["numbers", "letters"] + ).and_return() + + # Results: [ + # {"numbers": 1, "letters": 'a'}, + # {"numbers": 2, "letters": 'b'}, + # {"numbers": 3, "letters": 'c'} + # ] + +Using Bucketing +^^^^^^^^^^^^^^^ + +.. code-block:: python + + from fairseq2.data import create_bucket_sizes + + # Create optimal bucket sizes for sequence processing + bucket_sizes = create_bucket_sizes( + max_num_elements=1024, + max_seq_len=128, + min_seq_len=1, + num_seqs_multiple_of=8 + ) + + # Use bucketing in a pipeline + pipeline = ( + read_sequence(data) + .bucket_by_length( + bucket_sizes, + selector="text", + drop_remainder=False + ) + .and_return() + ) + +State Management +^^^^^^^^^^^^^^^^ + +.. code-block:: python + + # Save pipeline state + state = pipeline.state_dict() + + # Load pipeline state + new_pipeline = create_pipeline() # Create a new pipeline + new_pipeline.load_state_dict(state) # Restore the state + +See Also +-------- + +* :ref:`basics-data-pipeline` - Basic introduction to data pipeline diff --git a/doc/source/reference/api/fairseq2.data/index.rst b/doc/source/reference/api/fairseq2.data/index.rst new file mode 100644 index 000000000..6023a88af --- /dev/null +++ b/doc/source/reference/api/fairseq2.data/index.rst @@ -0,0 +1,17 @@ +============= +fairseq2.data +============= + +.. module:: fairseq2.data + +This module contains data pipeline operators. + +.. autoclasstree:: fairseq2.data + :full: + :zoom: + +.. toctree:: + :maxdepth: 1 + + text/index + data_pipeline diff --git a/doc/source/reference/api/fairseq2.data/text/index.rst b/doc/source/reference/api/fairseq2.data/text/index.rst new file mode 100644 index 000000000..27db2d9e5 --- /dev/null +++ b/doc/source/reference/api/fairseq2.data/text/index.rst @@ -0,0 +1,16 @@ +================== +fairseq2.data.text +================== + +.. module:: fairseq2.data.text + +This module contains text tokenizers and text specific data pipeline operators. + +.. autoclasstree:: fairseq2.data.text + :full: + :zoom: + +.. toctree:: + :maxdepth: 1 + + text_tokenizers diff --git a/doc/source/reference/api/fairseq2.data/text/text_tokenizers.rst b/doc/source/reference/api/fairseq2.data/text/text_tokenizers.rst new file mode 100644 index 000000000..7054520ec --- /dev/null +++ b/doc/source/reference/api/fairseq2.data/text/text_tokenizers.rst @@ -0,0 +1,3 @@ +=============== +Text Tokenizers +=============== diff --git a/doc/source/reference/api/fairseq2.datasets/index.rst b/doc/source/reference/api/fairseq2.datasets/index.rst new file mode 100644 index 000000000..4cf73b9d5 --- /dev/null +++ b/doc/source/reference/api/fairseq2.datasets/index.rst @@ -0,0 +1,62 @@ +================= +fairseq2.datasets +================= + +.. module:: fairseq2.datasets + +=============== +Dataset Loaders +=============== + +The dataset loader system in fairseq2 provides a flexible and extensible way to load different types of datasets. +The system uses the concept of dataset families to organize and manage different dataset formats. + +Dataset Family +-------------- + +A dataset family represents a specific format or structure of data that requires specialized loading logic. +Each dataset is associated with a family through the ``dataset_family`` field in its asset card. + +Built-in Dataset Families +^^^^^^^^^^^^^^^^^^^^^^^^^ + +fairseq2 includes several built-in dataset families: + +- ``generic_text``: For plain text datasets +- ``generic_parallel_text``: For parallel text/translation datasets +- ``generic_asr``: For automatic speech recognition datasets +- ``generic_speech``: For speech-only datasets +- ``generic_instruction``: For instruction-tuning datasets +- ``generic_preference_optimization``: For preference optimization datasets + +Example Asset Card +^^^^^^^^^^^^^^^^^^ + +.. code-block:: yaml + + name: librispeech_asr + dataset_family: generic_asr + tokenizer: "https://example.com/tokenizer.model" + tokenizer_family: char_tokenizer + +Usage Examples +-------------- + +Loading a Dataset Using Family +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from fairseq2.datasets import load_text_dataset + + # Load using dataset name (will look up asset card) + dataset = load_text_dataset("my_text_dataset") + + # Load using explicit asset card + card = AssetCard(name="custom_dataset", dataset_family="generic_text") + dataset = load_text_dataset(card) + +See Also +-------- + +- :doc:`Text Dataset ` diff --git a/doc/source/reference/api/fairseq2.nn/index.rst b/doc/source/reference/api/fairseq2.nn/index.rst new file mode 100644 index 000000000..f81938f0d --- /dev/null +++ b/doc/source/reference/api/fairseq2.nn/index.rst @@ -0,0 +1,19 @@ +=========== +fairseq2.nn +=========== + +.. module:: fairseq2.nn + +This module contains various PyTorch modules and related APIs to help with +building new model architectures. It follows similar conventions to :mod:`torch.nn` +and can be considered an addendum to it. + +.. autoclasstree:: fairseq2.nn + :full: + :zoom: + + +.. toctree:: + :maxdepth: 1 + + position_encoders diff --git a/doc/source/reference/api/fairseq2.nn/position_encoders.rst b/doc/source/reference/api/fairseq2.nn/position_encoders.rst new file mode 100644 index 000000000..cec7b4637 --- /dev/null +++ b/doc/source/reference/api/fairseq2.nn/position_encoders.rst @@ -0,0 +1,100 @@ +============================ +Position Encoders +============================ + +.. currentmodule:: fairseq2.nn + +A set of PyTorch modules to encode sequences with positional information. + +**ABCs** + +* :class:`PositionEncoder` + +**Classes** + +* :class:`SinusoidalPositionEncoder` +* :class:`LearnedPositionEncoder` +* :class:`RotaryEncoder` + +ABCs +==== + +.. autoclass:: PositionEncoder + + .. autoclasstree:: fairseq2.nn.PositionEncoder fairseq2.nn.SinusoidalPositionEncoder fairseq2.nn.LearnedPositionEncoder fairseq2.nn.RotaryEncoder + :full: + +Classes +======= + + +.. autoclass:: SinusoidalPositionEncoder + + + .. autoclasstree:: fairseq2.nn.SinusoidalPositionEncoder + :full: + + The positional encodings are initialized as in tensor2tensor which differs + slightly from the description in section 3.5 of + :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. This means instead of: + + .. math:: + + PE_{(pos, 2i)} = \text{sin}(pos/10000^{2i/d_{model}}) + + PE_{(pos, 2i+1)} = \text{cos}(pos/10000^{2i/d_{model}}) + + we use: + + .. math:: + + PE_{(pos, i)} = \text{sin}(pos/10000^{i/d_{model}})\;\text{for}\;i\; <\frac{d_{model}}{2} + + PE_{(pos, i)} = \text{cos}(pos/10000^{i/d_{model}})\;\text{for}\;i\;\geq\frac{d_{model}}{2} + + See `here `_ for more + information. + + Usage: + + >>> import torch + >>> + >>> from fairseq2.nn.position_encoder import SinusoidalPositionEncoder + >>> + >>> m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=16) + >>> + >>> seqs = torch.ones((3, 4)) + >>> + >>> m(seqs) + tensor([[ 1.0000e+00, 1.0000e+00, 2.0000e+00, 2.0000e+00], # pos 0 + [ 9.4147e-01, 2.0000e-04, 6.4030e-01, 2.0000e+00], # pos 1 + [ 1.0930e-02, 3.0000e-04, -5.1615e-01, 2.0000e+00]]) # pos 2 + + +.. autoclass:: LearnedPositionEncoder + + .. autoclasstree:: fairseq2.nn.LearnedPositionEncoder + :full: + + Usage: + + >>> import torch + >>> + >>> from fairseq2.nn.position_encoder import LearnedPositionEncoder + >>> + >>> m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=16) + >>> + >>> seqs = torch.ones((3, 4)) + >>> + >>> m(seqs) + tensor([[ 1.1135, 0.5548, 0.4293, 2.0112], # pos 0 + [ 0.2364, 0.6009, 3.3865, -2.4810], # pos 1 + [-0.4746, 0.4544, 0.2761, 0.8828]], grad_fn=) # pos 2 + + + + +.. autoclass:: RotaryEncoder + + .. autoclasstree:: fairseq2.nn.RotaryEncoder + :full: diff --git a/doc/source/reference/api/fairseq2.recipes/cli.rst b/doc/source/reference/api/fairseq2.recipes/cli.rst new file mode 100644 index 000000000..c726ea63a --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/cli.rst @@ -0,0 +1,127 @@ +==================== +fairseq2.recipes.cli +==================== + +.. currentmodule:: fairseq2.recipes.cli + +Classes +------- + +.. autoclass:: Cli + :members: + +.. autoclass:: CliGroup + :members: + +.. autoclass:: CliCommand + :members: + +.. autoclass:: CliCommandHandler + :members: + +.. autoclass:: RecipeCommandHandler + :members: + +Examples +-------- + +Creating a Custom CLI +===================== + +To create a custom CLI, you'll need to: + +1. Create a CLI group +2. Add commands to the group +3. Register your CLI extension + +Here's a complete example: + +.. code-block:: python + + from fairseq2.recipes.cli import Cli, CliCommandHandler, RecipeCommandHandler + + def setup_custom_cli(cli: Cli) -> None: + # Create a new command group + group = cli.add_group( + "custom", + help="Custom recipes and utilities" + ) + + # Add a command using RecipeCommandHandler + custom_handler = RecipeCommandHandler( + loader=load_custom_recipe, # this is the recipe entrypoint callback function + preset_configs=custom_presets, # this is the preset configs registry + default_preset="default", # this is the default preset name + sweep_allowed_keys=["model", "dataset"] # Optional + ) + + group.add_command( + name="custom_command", + handler=custom_handler, + help="Run custom recipe" + ) + +You can find more examples in our recipe examples: + +* :mod:`fairseq2.recipes.lm.instruction_finetune` +* :mod:`fairseq2.recipes.llama.convert_checkpoint` +* :mod:`fairseq2.recipes.wav2vec2.train` + +Recipe Command Handler +====================== + +The :class:`RecipeCommandHandler` class provides a standardized way to handle recipe commands. It automatically sets up: + +- Configuration management (presets, files, overrides) +- Output directory handling +- Logging setup +- Environment setup for distributed training + +Example implementation: + +.. code-block:: python + + from dataclasses import dataclass + from pathlib import Path + from typing import Callable + + @dataclass + class CustomConfig: + param1: str + param2: int + + def load_custom_recipe(config: CustomConfig, output_dir: Path) -> Callable[[], None]: + def run_recipe() -> None: + # Recipe implementation + pass + + return run_recipe + + # Create preset configs + custom_presets = ConfigRegistry(CustomConfig) + custom_presets.register("default", CustomConfig(param1="value", param2=42)) + +CLI Initialization Process +========================== + +The CLI system is initialized in the following order: + +1. :class:`Cli` instance is created in :meth:`fairseq2.recipes.main` +2. Core CLI groups are registered in :meth:`fairseq2.recipes._setup_cli` +3. Extension CLI groups are registered via :meth:`fairseq2.recipes._setup_cli_extensions` + +To add your own CLI extension: + +1. Create a Python package for your extension +2. Create an entry point in your package's ``setup.py``: + +.. code-block:: python + + setup( + name="my_fairseq2_extension", + entry_points={ + "fairseq2.cli": [ + "custom = my_extension.cli:setup_custom_cli" + ] + } + ) diff --git a/doc/source/reference/api/fairseq2.recipes/index.rst b/doc/source/reference/api/fairseq2.recipes/index.rst new file mode 100644 index 000000000..2d2a071e3 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/index.rst @@ -0,0 +1,17 @@ +================ +fairseq2.recipes +================ + +.. module:: fairseq2.recipes + +This module contains various recipes for training, evaluation, and inference of +different model architectures. + +.. toctree:: + :maxdepth: 1 + + lm/index + llama/index + wav2vec2/index + cli + trainer diff --git a/doc/source/reference/api/fairseq2.recipes/llama/convert_checkpoint.rst b/doc/source/reference/api/fairseq2.recipes/llama/convert_checkpoint.rst new file mode 100644 index 000000000..3237a56c9 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/llama/convert_checkpoint.rst @@ -0,0 +1,96 @@ +.. _reference-recipes-lm-convert-checkpoint: + +=================== +Convert Checkpoints +=================== + +.. module:: fairseq2.recipes.llama.convert_checkpoint + +The checkpoint conversion module provides utilities to convert fairseq2 model checkpoints to different formats for interoperability with other frameworks. + +Command Line Interface +---------------------- + +.. code-block:: bash + + fairseq2 llama convert_checkpoint --model + +Arguments +^^^^^^^^^ + +- ``--model ``: The model architecture name (e.g., ``llama3_2_1b``) to generate correct ``params.json`` +- ````: Directory containing the fairseq2 checkpoint (model.pt or model.{0,1,2...}.pt for sharded checkpoints) +- ````: Output directory to store the converted checkpoint + +Supported Architectures +----------------------- + +The converter supports various LLaMA architectures including: + +- LLaMA 1: 7B, 13B, 33B, 65B +- LLaMA 2: 7B, 13B, 70B +- LLaMA 3: 8B, 70B +- LLaMA 3.1: 8B, 70B +- LLaMA 3.2: 1B, 3B + +For the complete list of architectures and their configurations, see :mod:`fairseq2.models.llama.archs`. + +Output Format +------------- + +The converter produces: + +1. Model weights in the reference format: + - Single checkpoint: ``consolidated.00.pth`` + - Sharded checkpoints: ``consolidated.{00,01,02...}.pth`` + +2. ``params.json`` containing model configuration: + +.. code-block:: json + + { + "model": { + "dim": 2048, // Model dimension + "n_layers": 16, // Number of layers + "n_heads": 32, // Number of attention heads + "n_kv_heads": 8, // Number of key/value heads (if different from n_heads) + "multiple_of": 256, // FFN dimension multiple + "ffn_dim_multiplier": 1.5, // FFN dimension multiplier (if not 1.0) + "rope_theta": 500000.0, // RoPE theta value + "norm_eps": 1e-5 // Layer norm epsilon + } + } + +Usage Example +------------- + +1. Convert a fairseq2 checkpoint to reference format: + +.. code-block:: bash + + fairseq2 llama convert_checkpoint --model llama3_2_1b \ + /path/to/fairseq2/checkpoint \ + /path/to/output/dir + +2. Convert to HuggingFace format: + +After converting to reference format, use the HuggingFace conversion script to convert to HF format: + +.. code-block:: bash + + python -m transformers.models.llama.convert_llama_weights_to_hf \ + --input_dir /path/to/output/dir \ + --model_size 1B \ + --output_dir /path/to/hf/model + +API Details +----------- + +.. autoclass:: ConvertCheckpointCommandHandler + +See Also +-------- + +- :doc:`End-to-End Fine-Tuning Tutorial ` +- :class:`fairseq2.models.llama.factory.LLaMAConfig` +- :class:`fairseq2.models.llama.archs` \ No newline at end of file diff --git a/doc/source/reference/api/fairseq2.recipes/llama/index.rst b/doc/source/reference/api/fairseq2.recipes/llama/index.rst new file mode 100644 index 000000000..def35d4cf --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/llama/index.rst @@ -0,0 +1,18 @@ +====================== +fairseq2.recipes.llama +====================== + +.. module:: fairseq2.recipes.llama + + +.. autoclasstree:: fairseq2.recipes.llama + :full: + :zoom: + +See Also +======== + +.. toctree:: + :maxdepth: 1 + + convert_checkpoint diff --git a/doc/source/reference/api/fairseq2.recipes/lm/index.rst b/doc/source/reference/api/fairseq2.recipes/lm/index.rst new file mode 100644 index 000000000..b20229996 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/lm/index.rst @@ -0,0 +1,32 @@ +=================== +fairseq2.recipes.lm +=================== + +.. module:: fairseq2.recipes.lm + +Overview +======== +The ``fairseq2.recipes.lm`` module provides utilities and recipes for language model training and fine-tuning. +This includes tools for both pre-training and instruction tuning of language models. + +Key Features +============ +- Language model pre-training utilities +- Instruction fine-tuning support +- CLI setup for language model training +- Common training recipes and configurations + +Submodules +========== + +.. toctree:: + :maxdepth: 1 + + instruction_finetune + +The ``instruction_finetune`` module provides specialized utilities for instruction-based fine-tuning of language models. + +Usage Examples +============== + +- :ref:`tutorial-end-to-end-fine-tuning` diff --git a/doc/source/reference/api/fairseq2.recipes/lm/instruction_finetune.rst b/doc/source/reference/api/fairseq2.recipes/lm/instruction_finetune.rst new file mode 100644 index 000000000..c5500b359 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/lm/instruction_finetune.rst @@ -0,0 +1,20 @@ +======================================== +fairseq2.recipes.lm.instruction_finetune +======================================== + + +.. currentmodule:: fairseq2.recipes.lm.instruction_finetune + +.. autoclasstree:: fairseq2.recipes.lm + :full: + :zoom: + +Classes +======= + +.. autoclass:: InstructionFinetuneConfig + +Functions +========= + +.. autofunction:: load_instruction_finetuner \ No newline at end of file diff --git a/doc/source/reference/api/fairseq2.recipes/trainer.rst b/doc/source/reference/api/fairseq2.recipes/trainer.rst new file mode 100644 index 000000000..02164a626 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/trainer.rst @@ -0,0 +1,17 @@ +======================== +fairseq2.recipes.trainer +======================== + +.. module:: fairseq2.recipes.trainer + +.. autoclasstree:: fairseq2.recipes.trainer + :full: + :zoom: + + +Classes +======= + + +.. autoclass:: Trainer + :members: diff --git a/doc/source/reference/api/fairseq2.recipes/wav2vec2/index.rst b/doc/source/reference/api/fairseq2.recipes/wav2vec2/index.rst new file mode 100644 index 000000000..d2c211b83 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/wav2vec2/index.rst @@ -0,0 +1,17 @@ +========================= +fairseq2.recipes.wav2vec2 +========================= + +.. module:: fairseq2.recipes.wav2vec2 + + +Functions +========= + +.. autofunction:: _setup_wav2vec2_cli + + +.. toctree:: + :maxdepth: 1 + + train diff --git a/doc/source/reference/api/fairseq2.recipes/wav2vec2/train.rst b/doc/source/reference/api/fairseq2.recipes/wav2vec2/train.rst new file mode 100644 index 000000000..2207200a1 --- /dev/null +++ b/doc/source/reference/api/fairseq2.recipes/wav2vec2/train.rst @@ -0,0 +1,21 @@ +=============================== +fairseq2.recipes.wav2vec2.train +=============================== + +.. currentmodule:: fairseq2.recipes.wav2vec2.train + +.. autoclasstree:: fairseq2.recipes.wav2vec2 + :full: + :zoom: + +Classes +======= + +.. autoclass:: Wav2Vec2TrainConfig + +.. autoclass:: Wav2Vec2TrainUnit + +Functions +========= + +.. autofunction:: load_wav2vec2_trainer \ No newline at end of file diff --git a/doc/source/reference/api/fairseq2.rst b/doc/source/reference/api/fairseq2.rst new file mode 100644 index 000000000..d078d0e77 --- /dev/null +++ b/doc/source/reference/api/fairseq2.rst @@ -0,0 +1,16 @@ +======== +fairseq2 +======== + +.. module:: fairseq2 + +The root module contains library initialization functions. + +**Functions** + +* :class:`setup_fairseq2` + +Functions +========= + +.. autofunction:: setup_fairseq2 diff --git a/doc/source/reference/api/gang.rst b/doc/source/reference/api/gang.rst new file mode 100644 index 000000000..90a600df9 --- /dev/null +++ b/doc/source/reference/api/gang.rst @@ -0,0 +1,52 @@ +============= +fairseq2.gang +============= + +.. module:: fairseq2.gang + +This module provides the implementation of the ``Gang`` class and its related classes for managing collective operations in a distributed environment. + +Classes +------- + +.. autoclass:: Gang + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: AbstractGang + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: FakeGang + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ProcessGroupGang + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: GangError + :members: + :undoc-members: + +.. autoclass:: ReduceOperation + :members: + :undoc-members: + +Functions +--------- + +.. autofunction:: setup_default_gang +.. autofunction:: fake_gangs +.. autofunction:: setup_parallel_gangs +.. autofunction:: broadcast_flag +.. autofunction:: all_sum +.. autofunction:: get_world_size +.. autofunction:: get_rank +.. autofunction:: get_local_world_size +.. autofunction:: get_local_rank +.. autofunction:: is_torchrun diff --git a/doc/source/reference/api/index.rst b/doc/source/reference/api/index.rst new file mode 100644 index 000000000..88eadf8bc --- /dev/null +++ b/doc/source/reference/api/index.rst @@ -0,0 +1,15 @@ +.. _reference-api: + +:octicon:`file-code` API Reference +================================== + +.. toctree:: + :maxdepth: 1 + + fairseq2 + gang + fairseq2.data/index + fairseq2.datasets/index + fairseq2.nn/index + fairseq2.recipes/index + fairseq2.assets/index diff --git a/doc/source/reference/bibliography.rst b/doc/source/reference/bibliography.rst new file mode 100644 index 000000000..ac4ca9ce7 --- /dev/null +++ b/doc/source/reference/bibliography.rst @@ -0,0 +1,6 @@ +.. _reference-bibliography: + +:octicon:`book` Bibliography +============================ + +.. bibliography:: diff --git a/doc/source/tutorials/benchmarking.rst b/doc/source/tutorials/benchmarking.rst new file mode 100644 index 000000000..588db6aba --- /dev/null +++ b/doc/source/tutorials/benchmarking.rst @@ -0,0 +1,482 @@ +.. _tutorial-benchmarking: + +======================================= +:octicon:`clock` Efficient Benchmarking +======================================= + +.. dropdown:: What you will learn + :icon: multi-select + :animate: fade-in + + * How to benchmark language model training and inference + + * How to perform systematic hyperparameter sweeps + + * How to profile model performance using torch profiler + + * How to scale training to multiple nodes efficiently + + +.. dropdown:: Prerequisites + :icon: multi-select + :animate: fade-in + + * Get familiar with fairseq2 basics (:ref:`basics-overview`) + + * Ensure you have fairseq2 installed (:ref:`installation`) + + * Understand how to use built-in presets (:ref:`tutorial-presets`) + + * Familiarize yourself with recipes (:doc:`Recipe `) + + * Understand how to use CLI (:doc:`CLI `) + +.. image:: ../_static/img/tutorials/benchmark/2node_elapsed_time_relative.png + :align: center + :alt: 2 node elapsed time relative + :width: 600 + +Overview +-------- + +This tutorial will guide you through conducting systematic benchmarks using fairseq2. +We'll focus on practical examples using language models, covering: + +1. Training speed benchmarks +2. Multi-node scaling efficiency +3. Hyperparameter sweeps +4. Performance profiling + +.. note:: + + The examples will use LLaMA models, but the concepts apply to any model architecture. + + +Training Speed Benchmarks +------------------------- + +Let's start by benchmarking the training speed of different model configurations. + + +1. Environment Setup +^^^^^^^^^^^^^^^^^^^^ + +First, set up different virtual environments to test various PyTorch configurations. + +.. dropdown:: Example Environment Setup + :icon: multi-select + :animate: fade-in + + .. code-block:: bash + + # Create environments with different PyTorch versions + conda create -n fairseq2_pt22 python=3.10 + conda create -n fairseq2_pt24 python=3.10 + + # Install PyTorch 2.2 environment + conda activate fairseq2_pt22 + pip install torch==2.2.0 --index-url https://download.pytorch.org/whl/cu121 + pip install fairseq2 + + # Install PyTorch 2.4 environment + conda activate fairseq2_pt24 + pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121 + pip install fairseq2 + +.. note:: + + Follow the instructions in :ref:`installation` to install fairseq2 and PyTorch. + +2. Multi-Node Training +^^^^^^^^^^^^^^^^^^^^^^ + +fairseq2 CLI is designed to support distributed training across multiple nodes, and it facilitates the sweeping of hyperparameters across different environments. + +.. dropdown:: Example SLURM Script + :icon: code + :animate: fade-in + + .. code-block:: bash + + #!/bin/bash + #SBATCH --job-name=fairseq2_benchmark + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=8 + #SBATCH --gpus-per-node=8 + + # List of environments to test + envs=( + "fairseq2_pt22" + "fairseq2_pt24" + ) + + # Run benchmarks + for env_name in "${envs[@]}"; do + conda activate $env_name + for i in {0..1}; do # Two runs per environment + echo "Running $env_name run $i" + srun fairseq2 lm instruction_finetune \ + --preset llama3_1_70b_instruct \ + --config-file configs/benchmark.yaml \ + -- benchmark_outputs/${env_name}/run_${i} # output directory + done + conda deactivate + done + +.. dropdown:: Example ``benchmark.yaml`` + :icon: code + :animate: fade-in + + .. code-block:: yaml + + # Training config + max_num_steps: 1000 + batch_size: 4 + max_seq_len: 2048 + + # Distributed training + data_parallelism: "fsdp" + tensor_parallel_size: 8 + + # Optimization + optimizer: + lr: 2e-5 + weight_decay: 0.1 + + mixed_precision: "static" + dtype: "bfloat16" + +Hyperparameter Sweeps +--------------------- + +fairseq2 provides powerful sweep functionality with its :class:`fairseq2.recipes.utils.sweep_tagger.SweepTagger`. +It helps ensure: + +1. Consistent directory structure across nodes +2. Reproducible experiments +3. Easy comparison of different configurations + +For example, when running multi-node training: + +.. code-block:: bash + + #!/bin/bash + #SBATCH --job-name=mt_sweep + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=8 + #SBATCH --gpus-per-node=8 + + # Language pairs to sweep + lang_pairs=( + "eng-fra" + "eng-deu" + "eng-spa" + ) + + # Run MT sweeps + for pair in "${lang_pairs[@]}"; do + src_lang=${pair%-*} + tgt_lang=${pair#*-} + + # fairseq2 CLI will automatically use SweepTagger to create + # a unique directory based on the config + srun fairseq2 mt train \ + --preset nllb_600m \ + --config-file configs/mt.yaml \ + --config source_lang=$src_lang target_lang=$tgt_lang \ + -- sweep_outputs/ # Base output directory + +The fairseq2 CLI will: + +1. Parse the config file and command line overrides +2. Use :class:`fairseq2.recipes.utils.sweep_tagger.SweepTagger` to generate a unique tag based on sweep keys +3. Create a subdirectory using this tag under the base output directory +4. Ensure all nodes write to the same directory structure +5. If ``fmt`` is provided, it will be used to generate the tag in a customizable format + +.. note:: + + Use ``--no-sweep-dir`` when you want to disable automatic sweep directory creation. This is useful when: + + - Running quick tests/debugging + - Using custom directory structures + +Different recipes support different sweep keys. +The following examples will show how to configure sweep tags for different recipes. + +1. Language Model Sweeps +^^^^^^^^^^^^^^^^^^^^^^^^ + +For language models, we have two main finetuning approaches. + +.. dropdown:: Instruction Finetuning (SFT) + :icon: multi-select + :animate: fade-in + + .. code-block:: python + + from fairseq2.recipes.lm.instruction_finetune import ( + InstructionFinetuneConfig, + instruction_finetune_presets + ) + from fairseq2.recipes.utils.sweep_tagger import SweepTagger + + # Configure LM sweep + sweep_keys = { + "batch_size", + "max_seq_len", + "dtype", + "tensor_parallel_size" + } + + sweep_tagger = SweepTagger(world_size=8, allowed_keys=sweep_keys) + + # Example instruction finetuning config + config = { + "max_num_steps": 1000, + "batch_size": 4, + "max_seq_len": 2048, + "dtype": "bfloat16" + } + + # Generate unique tag for this config + tag = sweep_tagger.generate( + "llama3_1_70b_instruct", + config, + fmt="ps_{preset}.ws_{world_size}.{batch_size}_{max_seq_len}_{dtype}", + ) + output_dir = Path(f"sweep_outputs/{tag}") + +.. dropdown:: Preference Finetuning (DPO) + :icon: multi-select + :animate: fade-in + + .. code-block:: python + + from fairseq2.recipes.lm.preference_finetune.dpo import ( + DpoConfig, + create_dpo_unit + ) + from fairseq2.recipes.utils.sweep_tagger import SweepTagger + + # Configure DPO sweep + sweep_keys = { + "batch_size", + "max_seq_len", + "beta", # DPO-specific + "nll_scale", # DPO-specific + "reference_tensor_parallel_size", # DPO-specific + "length_normalization" # DPO-specific + } + + sweep_tagger = SweepTagger(world_size=8, sweep_keys=sweep_keys) + + # Example DPO config + config = { + "max_num_steps": 1000, + "batch_size": 4, + "max_seq_len": 2048, + "beta": 0.1, + "nll_scale": 0.0, + "reference_model": "llama3_1_8b_instruct", + "reference_tensor_parallel_size": 1, + "length_normalization": False + } + + # Generate unique tag for this config + tag = sweep_tagger.generate("llama3_1_8b_dpo", config) + output_dir = Path(f"sweep_outputs/{tag}") + + Example SLURM script for running DPO sweeps: + + .. code-block:: bash + + #!/bin/bash + #SBATCH --job-name=dpo_sweep + #SBATCH --nodes=4 + #SBATCH --ntasks-per-node=8 + #SBATCH --gpus-per-node=8 + + # List of beta values to sweep + betas=(0.1 0.2 0.5) + + # Run DPO sweeps + for beta in "${betas[@]}"; do + srun fairseq2 lm preference_finetune \ + --preset llama3_1_8b_dpo \ + --config-file configs/dpo.yaml \ + --config "beta=$beta" + -- sweep_outputs/ + done + +2. Machine Translation Sweeps +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +MT recipes include additional sweep keys specific to translation tasks. + +.. dropdown:: Example MT sweep + :icon: code + :animate: fade-in + + .. code-block:: python + + from fairseq2.recipes.mt.train import load_mt_trainer, mt_train_presets + from fairseq2.recipes.utils.sweep_tagger import SweepTagger + + # Configure MT sweep + sweep_keys = { + "lr", + "weight_decay", + "source_lang", # MT-specific + "target_lang", # MT-specific + "max_seq_len", + "batch_size" + } + + sweep_tagger = SweepTagger(world_size=8, sweep_keys=sweep_keys) + + # Example MT config + config = { + "source_lang": "eng", + "target_lang": "fra", + "optimizer_config": { + "lr": 2e-5, + "weight_decay": 0.1 + } + } + + # Generate unique tag for this config + tag = sweep_tagger.generate("nllb_600m", config) + output_dir = Path(f"sweep_outputs/{tag}") + +3. wav2vec2 Sweeps +^^^^^^^^^^^^^^^^^^ + +Speech models also have their own set of sweep parameters: + +.. dropdown:: Example wav2vec2 sweep + :icon: code + :animate: fade-in + + .. code-block:: python + + from fairseq2.models.wav2vec2.asr import wav2vec2_asr_archs + from fairseq2.recipes.utils.sweep_tagger import SweepTagger + + # wav2vec2-specific sweep keys + sweep_keys = { + "freeze_encoder_for_n_steps", + "max_audio_len", + "min_audio_len", + "normalize_audio", + } + + sweep_tagger = SweepTagger(world_size=8, allowed_keys=sweep_keys) + + # Example wav2vec2 config + config = { + "freeze_encoder_for_n_steps": 1_000, + "max_audio_len": 100_000, + "min_audio_len": 1_000, + "normalize_audio": True + } + + # Generate unique tag for this config + tag = sweep_tagger.generate( + "wav2vec2_base", + config, + fmt="ps_{preset}.ws_{world_size}.mal_{max_audio_len}.minal_{min_audio_len}.norm_{normalize_audio}", + ) + + output_dir = Path(f"sweep_outputs/{tag}") + +Performance Profiling +--------------------- + +fairseq2 uses PyTorch's profiler to help analyze performance bottlenecks. +The profiler results will be saved to TensorBoard format in the output directory. +It allows you to visualize the performance of your model in detail. +It is also a useful tool for gathering performance metrics for hyperparameter sweeps. + +.. dropdown:: Analysis of Profiler Results + :icon: multi-select + :animate: fade-in + + .. image:: ../_static/img/tutorials/benchmark/2node_eps_absolute.png + :align: center + :alt: Profiler Results + :width: 600 + + + To visualize the results, start Tensorboard at the output directory: + + .. code-block:: bash + + # Start Tensorboard + tensorboard --logdir ./profile_outputs/tb/ + + Access the results in your browser at http://localhost:6006. + + You can also plot the results in a customized way for your own analysis: + + .. code-block:: python + + from tensorboard.backend.event_processing import event_accumulator + import pandas as pd + import seaborn as sns + import matplotlib.pyplot as plt + + def parse_tensorboard(path, scalars): + ea = event_accumulator.EventAccumulator( + path, + size_guidance={event_accumulator.SCALARS: 0}, + ) + ea.Reload() + return {k: pd.DataFrame(ea.Scalars(k)) for k in scalars} + + def analyze_performance(log_dir): + # Parse metrics + metrics = parse_tensorboard(log_dir, ["Wall Time"]) # or "Elements per Second", "Elapsed Time" + + # Calculate statistics + wall_time = metrics["Wall Time"] + steps_per_second = len(wall_time) / wall_time["value"].sum() + + # Visualize + plt.figure(figsize=(10, 6)) + sns.lineplot(data=wall_time, x="step", y="value") + plt.title("Training Wall Time per Step") + plt.show() + + return steps_per_second + +Best Practices +-------------- + +1. **Systematic Benchmarking** + + - Always benchmark with fixed seeds for reproducibility + - Test multiple batch sizes and sequence lengths + - Measure both training and validation performance + - Record memory usage and throughput metrics + +2. **Distributed Training** + + - Start with single-node tests before scaling to multiple nodes + - Monitor communication overhead between nodes + - Use FSDP for large models that don't fit in GPU memory + - Experiment with different tensor parallel sizes + +3. **Performance Optimization** + + - Enable mixed precision training when possible + - Tune gradient accumulation steps + - Profile to identify bottlenecks + - Monitor GPU utilization and memory usage + +See Also +-------- + +- :doc:`Recipe ` +- :doc:`CLI ` +- :doc:`Presets ` diff --git a/doc/source/tutorials/end_to_end_fine_tuning.rst b/doc/source/tutorials/end_to_end_fine_tuning.rst new file mode 100644 index 000000000..c4d2d3abc --- /dev/null +++ b/doc/source/tutorials/end_to_end_fine_tuning.rst @@ -0,0 +1,427 @@ +.. _tutorial-end-to-end-fine-tuning: + +==================================================== +:octicon:`comment-discussion` End-to-End Fine-Tuning +==================================================== + +.. dropdown:: What you will learn + :icon: multi-select + :animate: fade-in + + * How to customize your assets (`e.g.` models, datasets, tokenizers) + + * How to run instruction fine-tuning recipe + + * How to use fairseq2 to generate (inference) + + * How to convert fairseq2 ckpt to huggingface ckpt for accelerated vllm inference + + * How to run fairseq2 with multiple nodes + +.. dropdown:: Prerequisites + :icon: multi-select + :animate: fade-in + + * Get familiar with fairseq2 basics (:ref:`basics-overview`) + + * Ensure you have fairseq2 installed (:ref:`installation`) + + +Overview +-------- + +#. **Prepare** + + * Download the `LLaMA3.2 1B model`_ from HuggingFace + + * Download the `gsm8k data`_ prepared for this tutorial + +#. **Fine-Tune** + + * One simple command to run the instruction fine-tuning recipe + + * Accelerate the training with multiple nodes + +#. **Generate** + + * One simple command to generate from the finetuned model + + * Convert fairseq2 model ckpt to hf ckpt for accelerated vllm inference + +#. **Go Beyond** + + * Use fairseq2 to accelerate your research + + +Prepare +------- + + +Model +^^^^^ + +Follow the `HuggingFace Models Tutorial`_ to download the `LLaMA3.2 1B model`_, which can be run on volta32gb GPUs. +Once you have the model in your local path, (`e.g.`` ``/models/Llama-3.2-1B/original/consolidated.00.pth``), +you need to register the model in a YAML card so that fairseq2 will know from where to pull the model +(read more about :ref:`basics-assets`). To do that: + +* Create a YAML file (e.g. ``my_llama3_2_1b.yaml``) with the following content: + +.. code-block:: yaml + + name: llama3_2_1b@user + checkpoint: "/models/Llama-3.2-1B/original/consolidated.00.pth" + + --- + + name: llama3@user + tokenizer: "/models/Llama-3.2-1B/original/tokenizer.model" + +.. tip:: + + The ``@user`` specifies this is your special environment. This can also be extended to help resolve different domain name for your clusters + + +* Save the file in one of the following locations: + + * `Option 1`: Place it in the default fairseq2 asset directory + + * ``mkdir -p ~/.config/fairseq2/assets`` + + * ``mv my_llama3_2_1b.yaml ~/.config/fairseq2/assets/`` + + * `Option 2`: Specify a custom directory and point ``FAIRSEQ2_USER_ASSET_DIR`` to it + + * ``export FAIRSEQ2_USER_ASSET_DIR=/path/to/custom/asset/directory`` + + * ``mv my_llama3_2_1b.yaml /path/to/custom/asset/directory/`` + +Dataset +^^^^^^^ + +Follow the `HuggingFace Datasets Tutorial`_ to download the `gsm8k data`_, (formatted with fairseq2 flavor) to your local path (`e.g.` ``/datasets/facebook/fairseq2-lm-gsm8k/``). +We will use the ``sft/train.jsonl`` to fine-tune the model and use the ``test/test.jsonl`` for evaluation. + + +Fine-Tune +--------- + +One-Liner +^^^^^^^^^ + +Running the fine-tuning recipe is as simple as: + +.. code-block:: bash + + fairseq2 lm instruction_finetune $OUTPUT_DIR --config \ + dataset=/datasets/facebook/fairseq2-lm-gsm8k/sft \ + model=llama3_2_1b \ + max_num_tokens=4096 \ + dtype=float16 \ + max_num_steps=1000 \ + max_num_data_epochs=20 \ + checkpoint_every_n_steps=1000 + + +.. dropdown:: You can also put the configuration in a YAML file + :icon: code + :animate: fade-in + + .. code-block:: yaml + + # /configs/example.yaml + dataset: /datasets/facebook/fairseq2-lm-gsm8k/sft + model: llama3_2_1b + max_num_tokens: 4096 + max_seq_len: 4096 + max_num_steps: 1000 + max_num_data_epochs: 20 + checkpoint_every_n_steps: 1000 + keep_last_n_checkpoints: 1 + keep_last_n_models: 1 + publish_metrics_every_n_steps: 5 + dtype: float16 # volta32gb gpus do not support bfloat16 + + Then run: + + .. code-block:: bash + + CONFIG_FILE=/configs/example.yaml + fairseq2 lm instruction_finetune $OUTPUT_DIR --config-file $CONFIG_FILE + + For more details about the recipe configuration, please refer to :ref:`basics-recipe`. + +Iterative Training +^^^^^^^^^^^^^^^^^^ + +Sometimes you may want to continue fine-tuning from a previously trained checkpoint, either to: + +- Resume interrupted training +- Fine-tune on additional data +- Perform iterative fine-tuning with different hyperparameters + +fairseq2 provides a clean way to handle this through the checkpoint system (learn more about :ref:`basics-ckpt-management`): + +.. code-block:: bash + + fairseq2 lm instruction_finetune $OUTPUT_DIR --config \ + resume_checkpoint_dir=/path/to/checkpoint \ + model="last_checkpoint" \ # this will pick up the last checkpoint + dataset=/path/to/data + +.. dropdown:: To pick up a specific checkpoint + :icon: code + :animate: fade-in + + .. code-block:: bash + + CKPT_PATH="/checkpoint/user/experiments/run_0/checkpoints/step_1000" # this is the path to the checkpoint + CKPT_DIR=$(dirname "$CKPT_PATH") # e.g., /checkpoint/user/experiments/run_0/checkpoints + CKPT="checkpoint_$(basename "$CKPT_DIR")" # e.g., checkpoint_step_1000 + + fairseq2 lm instruction_finetune $OUTPUT_DIR --config \ + resume_checkpoint_dir=$CKPT_DIR \ + model=$CKPT \ # Must match the checkpoint step + dataset=/path/to/new/data \ + max_num_tokens=4096 \ + dtype=float16 + + .. note:: + + If you want to pick a specific checkpoint instead of the last checkpoint, the ``model`` parameter must be set to ``checkpoint_step_X`` where X matches the step number of the checkpoint you want to load. + +.. dropdown:: A more detailed example + :icon: code + :animate: fade-in + + For iterative fine-tuning across different datasets or with different hyperparameters: + + .. code-block:: yaml + + # config.yaml + # First stage - train on dataset A + dataset: /path/to/dataset_A + model: llama3_2_1b + max_num_steps: 1000 + learning_rate: 1e-5 + # ... other config + + Then run the following commands in bash: + + .. code-block:: bash + + # First stage + fairseq2 lm instruction_finetune run1_output --config-file config.yaml + + # Second stage - continue from first stage checkpoint + fairseq2 lm instruction_finetune run2_output --config \ + resume_checkpoint_dir=run1_output/checkpoints \ + model=checkpoint_step_1000 \ + dataset=/path/to/dataset_B \ + learning_rate=5e-6 # Lower learning rate for second stage + max_num_steps=500 + + .. tip:: + + When doing iterative fine-tuning: + + - Generally use a lower learning rate in later stages + - Consider reducing the number of steps for later stages + - You may want to adjust the validation frequency + - Make sure to track metrics to compare performance across stages + +Multi-Node +^^^^^^^^^^ + +To help accelerate the training, fairseq2 is able to automatically detect multi-node setup. + +- `Option 1`: Slurm + + .. code-block:: bash + + srun --nodes=2 --ntasks-per-node=8 \ + fairseq2 lm instruction_finetune $OUTPUT_DIR \ + ... + +- `Option 2`: Torchrun + + .. code-block:: bash + + torchrun --standalone --nproc-per-node 8 --no-python \ + fairseq2 lm instruction_finetune $OUTPUT_DIR \ + ... + +Generate +-------- + +Once we have finished the training, we can find in the ``$OUTPUT_DIR`` the model checkpoints in ``$OUTPUT_DIR/checkpoints``. With that, we can now generate over the test dataset! + + +Native Support +^^^^^^^^^^^^^^ + +fairseq2 natively supports inference: + +.. code-block:: bash + + CKPT_PATH="/checkpoint/$USER/experiments/$EXPERIMENT_NAME/checkpoints/step_1000" + CKPT_DIR=$(dirname "$CKPT_PATH") + CKPT="checkpoint_$(basename "$CKPT_DIR")" # e.g., checkpoint_step_1000 + SAVE_DIR="$CKPT_DIR/generation" + DATASET="/datasets/facebook/fairseq2-lm-gsm8k/test/test.jsonl" + + fairseq2 lm generate $SAVE_DIR --no-sweep-dir --config \ + checkpoint_dir=$CKPT_DIR \ + model=$CKPT \ + generator_config.temperature=0.1 \ + dataset=$DATASET + + +VLLM Support +^^^^^^^^^^^^ + + +To accelerate the inference process, we can convert fairseq2 checkpoints to HuggingFace checkpoints, which can be deployed with VLLM. This takes 2 steps: + +**Step 1: Convert fairseq2 checkpoint to XLFormer checkpoint** + +The first step is to use the fairseq2 command-line (:ref:`basics-cli`) tool to convert the fairseq2 checkpoint to an XLF checkpoint. The command structure is as follows: + +.. code-block:: bash + + fairseq2 llama convert_checkpoint --model + + +* ````: Specify the architecture of the model -- `e.g.`, ``llama3`` (see :mod:`fairseq2.models.llama`) + +* ````: Path to the directory containing the Fairseq2 checkpoint + +* ````: Path where the XLF checkpoint will be saved + + +.. note:: + + Architecture ``--arch`` must exist and be defined in `e.g.` :meth:`fairseq2.models.llama.archs.register_archs`. + + +**Step 2: Convert XLFormer checkpoint to HF checkpoint** + +After obtaining the XLFormer checkpoint, the next step is to convert it to the Hugging Face format. Please refer to the official `HF script`_. + + +**Step 3: Deploy with VLLM** + +.. code-block:: python + + from vllm import LLM + + llm = LLM(model=) # path of your model + output = llm.generate("Hello, my name is") + print(output) + +Please refer to the `VLLM documentation`_ for more details. + +Check the Accuracy +^^^^^^^^^^^^^^^^^^ + +Once you generated the output, it is relatively trivial to compute the accuracy. Overall, you just need to: + +* Load the generated dataset + +* Load the original test dataset as ground truth + +* Compare and count the number of correct items + +.. dropdown:: Some example utils functions + :icon: code + :animate: fade-in + + .. code-block:: python + + import re + + ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") + INVALID_ANS = "[invalid]" + + + def extract_answer(completion: str) -> str: + """ + Extract the answer from the completion. + + :param completion: The completion. + :return: The answer. + """ + global ANS_RE, INVALID_ANS + match = ANS_RE.search(completion) + if match: + match_str = match.group(1).strip() + match_str = match_str.replace(",", "") + return match_str + else: + return INVALID_ANS + + + def is_correct(model_completion: str, gt_example: str) -> bool: + """ + Check if the model completion is correct. + + :param model_completion: The model completion. + :param gt_example: The ground truth example. + :return: True if the model completion is correct, False otherwise. + """ + gt_answer = extract_answer(gt_example) + assert gt_answer != INVALID_ANS + return extract_answer(model_completion) == gt_answer + + +Go Beyond +--------- + + +That's pretty much it to get you started. But you can do a lot more. fairseq2 is a powerful tool to help you accelerate and scale up your research. It allows: + +* Experiment with different hyper-parameter configurations; + +.. image:: /_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_elements_per_second.png + :width: 580px + :align: center + :alt: Elements per Second + +* Compare performance across various datasets or model architectures; + +.. image:: /_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_accuracy.png + :width: 580px + :align: center + :alt: Model Comparison + +* Profile resource usage and optimize training workflows; + +.. image:: /_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_trace.png + :width: 580px + :align: center + :alt: Tracing + +* Connect to your WanDB and monitor your experiments in real-time; + +.. image:: /_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_wandb.png + :width: 580px + :align: center + :alt: WandB + +Now, up for you to discover!!! + +See Also +-------- + +- :doc:`Design Philosophy ` +- :doc:`Recipe ` +- :doc:`CLI ` +- :doc:`Assets ` + + +.. _LLaMA3.2 1B model: https://huggingface.co/meta-llama/Llama-3.2-1B/tree/main +.. _gsm8k data: https://huggingface.co/datasets/facebook/fairseq2-lm-gsm8k +.. _HuggingFace Models Tutorial: https://huggingface.co/docs/hub/en/models-downloading +.. _HuggingFace Datasets Tutorial: https://huggingface.co/docs/hub/en/datasets-downloading +.. _HF script: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py +.. _VLLM documentation: https://vllm.readthedocs.io/en/latest/ diff --git a/doc/source/tutorials/models.rst b/doc/source/tutorials/models.rst new file mode 100644 index 000000000..49cb2c2d1 --- /dev/null +++ b/doc/source/tutorials/models.rst @@ -0,0 +1,498 @@ +.. _tutorial-models: + +:octicon:`ruby` Add Your Own Model +================================== + + +.. dropdown:: What you will learn + :icon: multi-select + :animate: fade-in + + * How to configure a model + + * How to register a model architecture + + * How to use model factories to create models + + * How to use model loaders to load models + +.. dropdown:: Prerequisites + :icon: multi-select + :animate: fade-in + + * Get familiar with fairseq2 basics (:ref:`basics-overview`) + + * Ensure you have fairseq2 installed (:ref:`installation`) + + * Get familiar with presets (:ref:`tutorial-presets`) + +Overview +-------- + +The model configuration and loading system in fairseq2 consists of several key components: + + +#. **Model Config** + + * Defines the architecture and hyperparameters of a model (`e.g. number of layers, hidden size, learning rate, etc.`) + +#. **Architecture Registry** + + * Stores predefined model architectures (`e.g. base, large, small, etc.`) + +#. **Model Factory** + + * Creates model instances from configs + +#. **Model Loader** + + * Handles model instantiation, checkpoint loading and format conversion (`e.g. loading from fairseq2 checkpoint, converting from HF checkpoint, etc.`) + + +Directory Layout +---------------- + +The directory structure for a typical fairseq2 model looks like this: + +.. code-block:: bash + + fairseq2/models/ + ├── your_model/ + │ ├── __init__.py + │ ├── archs.py # Defines model architectures + │ ├── factory.py # Contains model factory and config classes + │ ├── loader.py # Handles model loading and checkpoint conversion + │ └── model.py # Actual model implementation + +.. note:: + The actual layout might vary depending on your implementation. + +Step-by-Step Guide +------------------ + +1. Define Model Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, create a configuration class in ``factory.py``: + +.. code-block:: python + + from dataclasses import dataclass + from fairseq2.typing import DataType + from fairseq2.data import VocabularyInfo + + @dataclass(kw_only=True) + class YourModelConfig: + """Configuration for YourModel.""" + # Basic model parameters + model_dim: int = 512 + """The dimensionality of the model.""" + + num_layers: int = 6 + """The number of layers in the model.""" + + num_heads: int = 8 + """The number of attention heads in the model.""" + + ... + +In the same file, create a registry for the model config: + +.. code-block:: python + + your_model_config_registry = ConfigRegistry[YourModelConfig]() + + your_model_arch = your_model_config_registry.decorator + +This ``your_model_arch`` is a decorator that can be later used to register model architectures. + + +2. Register Model Architectures +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Create an architecture registry and define standard architectures in ``archs.py``: + +.. code-block:: python + + from fairseq2.models.your_model.factory import your_model_arch + + @your_model_arch("base") + def _base() -> YourModelConfig: + """Base architecture.""" + return YourModelConfig() + + @your_model_arch("large") + def _large() -> YourModelConfig: + """Large architecture.""" + config = YourModelConfig() + config.model_dim = 1024 + config.num_layers = 12 + config.num_heads = 16 + return config + +.. note:: + Keep the architecture names descriptive and simple. Document differences between architectures. + + +.. dropdown:: Some real-world examples + :icon: code + :animate: fade-in + + * **Base Transformer Architecture** + + The base Transformer model provides a foundation that other models can build upon: + + .. code-block:: python + + # In transformer/archs.py + from fairseq2.models.transformer.factory import TransformerConfig, transformer_arch + + @transformer_arch("base") + def _base() -> TransformerConfig: + """Base architecture with default parameters.""" + return TransformerConfig() + + @transformer_arch("big") + def _big() -> TransformerConfig: + """Larger architecture with modified parameters.""" + config = TransformerConfig() + config.model_dim = 1024 + config.num_encoder_attn_heads = 16 + config.num_decoder_attn_heads = 16 + config.ffn_inner_dim = 4096 + config.dropout_p = 0.3 + return config + + + * **NLLB (No Language Left Behind)** + + NLLB extends the base Transformer architecture with specific configurations for multilingual translation: + + .. code-block:: python + + # In nllb/archs.py + @transformer_arch("nllb_dense_600m") + def _dense_600m() -> TransformerConfig: + config = _dense_1b() # Inherits from larger architecture + + # Modify for smaller model + config.num_encoder_layers = 12 + config.num_decoder_layers = 12 + config.ffn_inner_dim = 1024 * 4 + + return config + + @transformer_arch("nllb_dense_1b") + def _dense_1b() -> TransformerConfig: + config = transformer_archs.get("base") # Start from base transformer + + # Customize for NLLB + config.model_dim = 1024 + config.vocab_info = VocabularyInfo( + size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0 + ) + config.num_encoder_layers = 24 + config.num_decoder_layers = 24 + config.num_encoder_attn_heads = 16 + config.num_decoder_attn_heads = 16 + config.ffn_inner_dim = 1024 * 8 + config.norm_order = TransformerNormOrder.PRE + + return config + + + * **LLaMA Architecture** + + LLaMA introduces its own configuration class with specific parameters for large language models: + + .. code-block:: python + + # In llama/archs.py + @llama_arch("7b") + def _7b() -> LLaMAConfig: + """7B parameter model.""" + return LLaMAConfig() # Uses default parameters + + @llama_arch("13b") + def _13b() -> LLaMAConfig: + """13B parameter model.""" + config = _7b() + config.model_dim = 5120 + config.num_attn_heads = 40 + config.num_key_value_heads = 40 + config.ffn_inner_dim = 5120 * 4 + return config + + @llama_arch("llama2_70b") + def _llama2_70b() -> LLaMAConfig: + """LLaMA 2 70B parameter model.""" + config = _65b() + config.max_seq_len = 4096 + config.num_key_value_heads = 8 + config.ffn_inner_dim = int(8192 * 4 * 1.3) # See A.2.1 in LLaMA 2 + config.ffn_inner_dim_to_multiple = 4096 + return config + + +3. Create Model Factory +^^^^^^^^^^^^^^^^^^^^^^^ + +Implement a factory function in ``factory.py`` that creates model instances: + +.. code-block:: python + + def create_your_model(config: YourModelConfig) -> YourModel: + """Create a model instance from config.""" + model = YourModel( + model_dim=config.model_dim, + num_layers=config.num_layers, + num_heads=config.num_heads, + dropout_p=config.dropout_p, + vocab_info=config.vocab_info, + ) + + # Convert to specified dtype + model.to(dtype=config.dtype) + + return model + + +.. dropdown:: Some real-world examples + :icon: code + :animate: fade-in + + * **LLaMA Model Factory** + + We will use the ``fairseq2.models.llama.factory.create_llama_model`` function as an example. + + The ``create_llama_model`` function serves as a factory method for instantiating a LLaMA model. + It encapsulates the process of building a model with the ``LLaMABuilder`` class, which constructs various components of the model based on the provided configuration. + This design pattern allows for a clean separation of model creation logic, making it easier for users to customize and extend the model architecture. + + .. code-block:: python + + # In llama/factory.py + class LLaMABuilder: + ... + + def build_model(self) -> TransformerDecoderModel: + """Build a model.""" + decoder_frontend = self.build_decoder_frontend() + + decoder = self.build_decoder() + + final_proj = Linear(...) + + model = TransformerDecoderModel( + decoder_frontend, + decoder, + final_proj, + ... + ) + + model.set_family(LLAMA_FAMILY) + + return model + + + def create_llama_model( + config: LLaMAConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> TransformerDecoderModel: + """Create a LLaMA model.""" + return LLaMABuilder(config, device=device, dtype=dtype).build_model() + + + model_factories.register(LLAMA_FAMILY, create_llama_model, LLaMAConfig, llama_archs) + + `create_llama_model` instantiates your builder class and call the `build_model` method that actually creates the model as a `TransformerDecoderModel`. + Don't forget to register your model with the fairseq2 model factories so that it can be easily instantiated later. + +4. Set Up Model Loader +^^^^^^^^^^^^^^^^^^^^^^ + +Create a loader in ``loader.py`` that handles model instantiation and checkpoint loading: + +.. code-block:: python + + from fairseq2.models.config_loader import StandardModelConfigLoader + from fairseq2.models.loader import StandardModelLoader, load_model + + # Create config loader + load_your_model_config = StandardModelConfigLoader( + YOUR_MODEL_FAMILY, + YourModelConfig, + your_model_archs + ) + + def convert_your_model_checkpoint( + checkpoint: dict[str, Any], config: YourModelConfig + ) -> dict[str, Any]: + """Convert external checkpoints to fairseq2 format.""" + # Add checkpoint conversion logic here + return {"model": checkpoint} + + # Create model loader + load_your_model = StandardModelLoader( + config_loader=load_your_model_config, + factory=create_your_model, + checkpoint_converter=convert_your_model_checkpoint, + ) + + # Register loader with global registry + load_model.register(YOUR_MODEL_FAMILY, load_your_model) + +.. dropdown:: Some real-world examples on ckpt conversion + :icon: code + :animate: fade-in + + The `convert_your_model_checkpoint` function is a checkpoint converter that converts external checkpoints to fairseq2 format. + For example, in Mistral, the checkpoint format is different from fairseq2's. + + .. code-block:: python + + # In mistral/loader.py + def convert_mistral_checkpoint( + checkpoint: dict[str, Any], config: MistralConfig + ) -> dict[str, Any]: + """Convert Mistral checkpoint to fairseq2 format.""" + if "model" in checkpoint: # Already in fairseq2 format + return checkpoint + + # Map parameter names from Mistral to fairseq2 format + key_map = { + r"^layers\.([0-9]+)\.attention\.wq\.": r"decoder.layers.\1.self_attn.q_proj.", + r"^layers\.([0-9]+)\.attention\.wk\.": r"decoder.layers.\1.self_attn.k_proj.", + r"^layers\.([0-9]+)\.attention\.wv\.": r"decoder.layers.\1.self_attn.v_proj.", + # ... more mappings + } + + checkpoint = convert_model_state_dict(checkpoint, key_map) + return {"model": checkpoint} + + Overall, to support loading from different checkpoint formats: + + 1. Modify the checkpoint converter function + 2. Add mapping logic for different parameter names + 3. Handle any necessary tensor transformations + +.. dropdown:: Advanced topic: Sharding + :icon: code + :animate: fade-in + + The ``sharder`` argument in ``StandardModelLoader`` is a function that shards the model, which is useful for distributed training. + This is natively supported by fairseq2, so you don't need to implement it yourself. + For example, in LLaMA, the ``shard_llama_model`` function shards the model across multiple devices: + + .. code-block:: python + + # In llama/loader.py + from fairseq2.models.transformer import shard_transformer_decoder_model + from fairseq2.models.loader import StandardModelLoader + + def shard_llama_model( + model: TransformerDecoderModel, config: LLaMAConfig, gangs: Mapping[str, Gang] + ) -> None: + gang = gangs["tp"] # tensor parallel + + shard_embed_dim = config.max_seq_len < 8192 # LLaMA 1 or 2 + + shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim) + + + load_llama_model = StandardModelLoader( + ... + sharder=shard_llama_model, + ) + +5. Using with Trainer +^^^^^^^^^^^^^^^^^^^^^ + +The model can be used with the fairseq2 trainer: + +.. code-block:: python + + from fairseq2.models.loader import load_model + from fairseq2.recipes.trainer import Trainer, TrainUnit + from fairseq2.recipes.utils.asset import retrieve_asset_card + + model_card = retrieve_asset_card("llama3_2_1b") + + # Load model + model = load_model( + model_card, + device=Device("cpu") + ) + + # Create training unit + class YourTrainUnit(AbstractTrainUnit[SequenceBatch]): + def __init__(self, model: YourModel) -> None: + super().__init__(model) + self._metric_bag = MetricBag() + + def __call__(self, batch: YourBatchType) -> tuple[Tensor, int]: + loss = self._model(**batch) + return loss, batch.num_targets + + # Set up trainer + trainer = Trainer( + unit=YourTrainUnit(model), + data_reader=your_data_reader, + optimizer=your_optimizer, + # ... other trainer parameters + ) + + # Run training + trainer() + +For a real-world example, see the :mod:`fairseq2.recipes.lm` recipe. + + +Best Practices +-------------- + +#. **Configuration**: + + * Provide sensible defaults for all parameters + * Document each config parameter + +#. **Architecture Registry**: + + * Use descriptive names for architectures + * Keep base architectures simple + * Document differences between architectures + +#. **Model Loading**: + + * Handle checkpoint format differences gracefully + * Validate config parameters before model creation + * Provide clear error messages for invalid configs + +#. **Training Integration**: + + * Create a dedicated training unit for your model + * Implement proper metric tracking + * Handle device placement and dtype conversion + +Common Pitfalls +--------------- + +#. **Checkpoint Compatibility**: + + * Ensure checkpoint conversion handles all parameter mappings + * Verify tensor shapes and dtypes match + * Handle missing or extra parameters gracefully + +#. **Configuration Issues**: + + * Validate all config parameters before use + * Handle interdependent parameters correctly + * Document any parameter constraints + +#. **Training Problems**: + + * Ensure proper device placement + * Handle batch processing efficiently + * Implement correct loss computation diff --git a/doc/source/tutorials/monitor_your_experiments.rst b/doc/source/tutorials/monitor_your_experiments.rst new file mode 100644 index 000000000..a240f1e5d --- /dev/null +++ b/doc/source/tutorials/monitor_your_experiments.rst @@ -0,0 +1,101 @@ +.. _tutorials-monitor-your-experiments: + +:octicon:`codescan-checkmark` Monitor Your Experiments +====================================================== + + +.. dropdown:: What you will learn + :icon: multi-select + :animate: fade-in + + * How to monitor your experiments using Tensorboard + + * How to monitor your experiments using WanDB + +.. dropdown:: Prerequisites + :icon: multi-select + :animate: fade-in + + * Get familiar with fairseq2 basics (:ref:`basics-overview`) + + * Ensure you have fairseq2 installed (:ref:`installation`) + +TensorBoard +----------- + + +.. image:: /_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_trace.png + :width: 580px + :align: center + :alt: TensorBoard + + +fairseq2 saves checkpoints and tensorboard events to the defined ``$OUTPUT_DIR``, which allows you to investigate into the details in your jobs. + +.. code-block:: bash + + # run tensorboard at your ckpt path + tensorboard --logdir $CHECKPOINT_PATH + + # example + tensorboard --logdir /checkpoint/$USER/outputs/ps_llama3_1_instruct.ws_16.a73dad52/tb/train + +If you ran your experiment on your server, you probably need to port forward the tensorboard service to your local machine: + +.. code-block:: bash + + ssh -L 6006:localhost:6006 $USER@$SERVER_NAME + +Then you can view the tensorboard service in your browser `http://localhost:6006 `__. + + +WanDB +------ + + +.. image:: /_static/img/tutorials/end_to_end_fine_tuning/tutorial_example_wandb.png + :width: 580px + :align: center + :alt: WandB + +fairseq2 natively support WanDB (Weights & Biases) - a powerful tool for monitoring and managing machine learning experiments. +WanDB provides a centralized platform to track, compare, and analyze the performance of different models, making it easier to identify trends, optimize hyperparameters, and reproduce results. +Follow the `quick start guide `__ to initialize it in your environment. + +What you need to do is simply add the following line in your config YAML file: + +.. code-block:: yaml + + wandb_project: + +Then run your recipe with ``fairseq2 ... --config-file .yaml``. + +Or you can directly specify with ``fairseq2 ... --config wandb_project=``. + +Then you can open up your WanDB Portal and check the results in real-time. + + +.. dropdown:: A step-by-step example + :icon: code + :animate: fade-in + + .. code-block:: bash + + ENV_NAME=... # YOUR_ENV_NAME + CONFIG_FILE=... # YOUR_CONFIG_FILE + OUTPUT_DIR=... # YOUR_OUTPUT_DIR + WANDB_PROJECT_NAME=... # YOUR_PROJECT_NAME + + conda activate $ENV_NAME + # install wandb + pip install wandb + # initialize wandb, copy paste your token when prompted + wandb login --host=... # your wandb hostname + + # now you are good to go + fairseq2 lm instruction_finetune $OUTPUT_DIR \ + --config-file $CONFIG_FILE \ + --config wandb_project=$WANDB_PROJECT_NAME \ + + # cleanup + conda deactivate diff --git a/doc/source/tutorials/presets.rst b/doc/source/tutorials/presets.rst new file mode 100644 index 000000000..53d87d0fb --- /dev/null +++ b/doc/source/tutorials/presets.rst @@ -0,0 +1,214 @@ +.. _tutorial-presets: + +==================================== +:octicon:`gear` Working with Presets +==================================== + +.. dropdown:: What you will learn + :icon: multi-select + :animate: fade-in + + * What presets are and why they are useful + + * How to use built-in presets + + * How to create custom presets + + * How to override preset configurations + +.. dropdown:: Prerequisites + :icon: multi-select + :animate: fade-in + + * Get familiar with fairseq2 basics (:ref:`basics-overview`) + + * Ensure you have fairseq2 installed (:ref:`installation`) + + * Familiarize yourself with recipes (:doc:`Recipe `) + + * Optionally, checkout the end to end fine-tuning tutorial (:ref:`tutorial-end-to-end-fine-tuning`) + + + +Overview +-------- + +Presets are pre-defined configurations that help you quickly get started with common training scenarios. +They encapsulate best practices and tested hyperparameters for specific use cases. +They also allows quick hyperparameter sweeps. + +The key benefits of using presets are: + +* Reduce boilerplate configuration code +* Start with proven configurations +* Easily customize for your needs +* Share configurations across experiments + + +Using Built-in Presets +---------------------- + +fairseq2 comes with several built-in presets for common scenarios. To use a preset: + +1. List available presets: + +.. code-block:: bash + + fairseq2 lm instruction_finetune --list-presets + +2. Use a preset: + +.. code-block:: bash + + fairseq2 lm instruction_finetune $OUTPUT_DIR --preset base_10h + +The preset will set default values for all configuration parameters. +You can override any of these values using ``--config``. + + +Creating Custom Presets +----------------------- + +You can create custom presets by: + +1. Define a configuration class (if not using an existing one) + +.. code-block:: python + + @dataclass(kw_only=True) + class MyTrainConfig: + """Configuration for my training task.""" + + learning_rate: float = 1e-4 + """The learning rate.""" + + batch_size: int = 32 + """The batch size.""" + + profile: tuple[int, int] | None = None + """The number of steps that the PyTorch profiler should skip and then record.""" + +2. Create a preset registry + +.. code-block:: python + + my_train_presets = ConfigRegistry[MyTrainConfig]() + + my_train_preset = my_train_presets.decorator + +3. Define presets using the decorator + +.. code-block:: python + + @my_train_preset("fast") + def _fast() -> MyTrainConfig: + return MyTrainConfig( + learning_rate=1e-3, + batch_size=64, + profile=(1000, 10), # skip 1000 steps then record 10 steps + ) + + @my_train_preset("accurate") + def _accurate() -> MyTrainConfig: + return MyTrainConfig( + learning_rate=1e-5, + batch_size=16, + profile=(1000, 10), # skip 1000 steps then record 10 steps + ) + +For a complete example of preset implementation, here are a couple of examples: + +* :mod:`fairseq2.recipes.wav2vec2.train ` + +* :mod:`fairseq2.recipes.lm.instruction_finetune ` + + +Overriding Preset Values +------------------------ + +You can override any preset values in two ways: + +1. Using command line arguments: + +.. code-block:: bash + + fairseq2 lm instruction_finetune $OUTPUT_DIR \ + --preset llama3_1_instruct \ + --config learning_rate=2e-4 batch_size=16 + +2. Using a YAML configuration file: + +.. code-block:: yaml + + # my_config.yaml + learning_rate: 2e-4 + batch_size: 16 + +.. code-block:: bash + + fairseq2 lm instruction_finetune $OUTPUT_DIR \ + --preset llama3_1_instruct \ + --config-file my_config.yaml + +The override precedence is: + +1. Command line overrides (highest priority) +2. Config file values +3. Preset defaults (lowest priority) + +Best Practices +-------------- + +* Start with an existing preset close to your use case +* Create custom presets for configurations you use frequently +* Document preset parameters and their effects +* Use meaningful preset names that indicate their purpose +* Keep presets focused on specific scenarios +* Version control your custom presets + +Go Beyond +--------- + +Once you are familiar with presets, you can go beyond and easily run hyperparameter sweeps. + +.. dropdown:: A dummy slurm example + :icon: code + :animate: fade-in + + .. code-block:: bash + + presets=( + "preset_fast" + "preset_accurate" + "preset_default" + ) + + batch_sizes=( + "16" + "32" + "64" + ) + + output_dir= + + for preset in "${presets[@]}"; do + for batch_size in "${batch_sizes[@]}"; do + echo "Running preset::$preset | batch_size::$batch_size" + srun fairseq2 train $output_dir/$preset/batch_size_$batch_size \ + --preset $preset \ + --config batch_size=$batch_size + done + done + +It will be much easier for you to manage your experiments and benchmark training speed to multiple nodes. + +.. image:: /_static/img/tutorials/presets/tutorial_presets_benchmark.png + :width: 600px + :align: center + :alt: Benchmark + +See Also +-------- + +- :doc:`Recipe ` +- :doc:`CLI ` diff --git a/doc/source/tutorials/pudb.rst b/doc/source/tutorials/pudb.rst new file mode 100644 index 000000000..1d55b66ab --- /dev/null +++ b/doc/source/tutorials/pudb.rst @@ -0,0 +1,111 @@ +.. _tutorial-pudb: + +================================== +:octicon:`bug` Debugging with PuDB +================================== + + +.. dropdown:: What you will learn + :icon: multi-select + :animate: fade-in + + * How to debug multi-node training using PuDB + + +.. dropdown:: Prerequisites + :icon: multi-select + :animate: fade-in + + * Get familiar with fairseq2 basics (:ref:`basics-overview`) + + * Ensure you have fairseq2 installed (:ref:`installation`) + + * Understand how to use CLI (:doc:`CLI `) + + * Install PuDB: ``pip install pudb`` + +This tutorial explains how to debug your training sessions, including multi-node runs, using the `PuDB debugger `. +PuDB is one of several remote debuggers you can use with fairseq2. + +Placing the debugger breakpoint in the code +------------------------------------------- + +Before setting a breakpoint, decide where in your code you want to start debugging. +Since fairseq2 supports multi-process training, ensure that the debugger is only invoked on the main process (rank 0) to prevent deadlocks. + +Insert the following code where you want to set the breakpoint: + +.. code-block:: python + + from fairseq2.gang import get_rank + if get_rank() == 0: + from pudb.remote import set_trace + + set_trace(host="meta-fairseq2", port=6899, term_size=(80*3, 24*3), reverse=True) + +**Explanation:** + +- ``host="meta-fairseq2"``: Replace with the hostname accessible to both the machine running fairseq2 and your local machine. +- ``port=6899``: Choose an appropriate port that is open and not in use. +- ``term_size=(80*3, 24*3)``: Sets the terminal size for the debugger interface. +- ``reverse=True``: Instructs the debugger to initiate the connection from the host. + + +Initializing the socket for remote debugger +------------------------------------------- + +On the host machine specified in the ``host`` parameter (`e.g.`, in our case it's ``meta-fairseq2``), run the following command to start listening on the specified port: + +.. code-block:: bash + + stty -echo -icanon && nc -l -p 6899 + +.. note:: + + - The command will appear to hang, which is expected as it's waiting for the debugger to connect. + - Ensure that the chosen port (``6899`` in this case) is open and accessible. + + +Running fairseq2 with debugger +------------------------------ + +In the other terminal / pane you need to start the fairseq2 training as usual. Here we show an example using slurm cluster. + +1. **Allocate Resources:** + + Obtain a compute allocation based on your cluster's configuration. Here's an example command using SLURM: + + .. code-block:: bash + + # Adjust the arguments (`--nodes`, `--ntasks-per-node`, etc.) as needed for your environment + salloc --nodes=1 --ntasks-per-node=8 --cpus-per-task=10 -t 1:00:00 --gpus-per-node=8 + + +2. **Start Training:** + + Launch your fairseq2 training job as you normally would. For example, for LLM training: + + .. code-block:: bash + + srun fairseq2 lm preference_finetune_w_eval $OUTPUT_DIR --no-sweep-dir --config-file $CONFIG_YAML + +3. **Connect to the Debugger:** + + Once the training reaches the breakpoint, the PuDB interface will appear in the terminal where you initialized the socket. + +Example screenshot of the debugger: + +.. image:: ../_static/img/tutorials/pudb.png + :align: center + :alt: PuDB example + :width: 600 + +Please refer to the `PuDB docs and repo `_ to explore more features and familiarize yourself with the interface. +PuDB supports all standard ``pdb`` commands in the source view and offers additional functionality for an enhanced debugging experience. + + +Exiting the debugger +-------------------- + +Press ``q`` to quit the debugger. +This will terminate the socket session and stop the training job. diff --git a/doc/static/img/logo.png b/doc/static/img/logo.png deleted file mode 100644 index 75472cbb5..000000000 Binary files a/doc/static/img/logo.png and /dev/null differ diff --git a/doc/templates/autosummary/class.rst b/doc/templates/autosummary/class.rst deleted file mode 100644 index 61d94d463..000000000 --- a/doc/templates/autosummary/class.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. currentmodule:: {{ module }} - -{{ name | escape | underline }} - -.. autoclass:: {{ name }} - :members: - :member-order: groupwise - :class-doc-from: both - :special-members: __call__, __iter__ - :inherited-members: Module - :show-inheritance: diff --git a/doc/templates/autosummary/data.rst b/doc/templates/autosummary/data.rst deleted file mode 100644 index c389511b4..000000000 --- a/doc/templates/autosummary/data.rst +++ /dev/null @@ -1,5 +0,0 @@ -.. currentmodule:: {{ module }} - -{{ name | escape | underline }} - -.. autodata:: {{ name }} diff --git a/doc/templates/autosummary/function.rst b/doc/templates/autosummary/function.rst deleted file mode 100644 index 0712f2eb2..000000000 --- a/doc/templates/autosummary/function.rst +++ /dev/null @@ -1,5 +0,0 @@ -.. currentmodule:: {{ module }} - -{{ name | escape | underline }} - -.. autofunction:: {{ name }} diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index e00b15327..ad1f3706b 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -6,7 +6,7 @@ cmake_minimum_required(VERSION 3.21.0) -project(fairseq2n VERSION 0.3.0 LANGUAGES C CXX) +project(fairseq2n VERSION 0.4.0 LANGUAGES C CXX) if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) set_property(CACHE CMAKE_BUILD_TYPE PROPERTY VALUE RelWithDebInfo) diff --git a/native/python/setup.py b/native/python/setup.py index ac6671c4e..277d0d684 100644 --- a/native/python/setup.py +++ b/native/python/setup.py @@ -7,7 +7,7 @@ from __future__ import annotations from os import path -from typing import Final, List, Optional +from typing import Final import torch from setuptools import Command, find_packages, setup @@ -88,7 +88,7 @@ def run(self) -> None: self._cmake_install(component="python") - def _cmake_install(self, component: Optional[str] = None) -> None: + def _cmake_install(self, component: str | None = None) -> None: cmd = ["cmake", "--install", self.cmake_build_dir] if component: @@ -101,7 +101,7 @@ def _cmake_install(self, component: Optional[str] = None) -> None: self.spawn(cmd) - def get_outputs(self) -> List[str]: + def get_outputs(self) -> list[str]: outputs = [] if self.bundle_lib: @@ -125,7 +125,7 @@ def get_outputs(self) -> List[str]: return outputs - def get_inputs(self) -> List[str]: + def get_inputs(self) -> list[str]: # We take no input. return [] @@ -137,7 +137,7 @@ def get_inputs(self) -> List[str]: "install_cmake": install_cmake, }, name="fairseq2n", - version="0.3.0.dev0", + version="0.4.0.dev0", description="FAIR Sequence Modeling Toolkit (Native)", long_description="https://github.com/facebookresearch/fairseq2", long_description_content_type="text/plain", diff --git a/native/python/src/fairseq2n/__init__.py b/native/python/src/fairseq2n/__init__.py index 195418687..ae431c84e 100644 --- a/native/python/src/fairseq2n/__init__.py +++ b/native/python/src/fairseq2n/__init__.py @@ -6,14 +6,13 @@ from __future__ import annotations -__version__ = "0.3.0.dev0" +__version__ = "0.4.0.dev0" import platform import site from ctypes import CDLL, RTLD_GLOBAL from os import environ from pathlib import Path -from typing import List, Optional, Tuple from fairseq2n.config import ( _CUDA_VERSION, @@ -59,7 +58,7 @@ def supports_cuda() -> bool: return _SUPPORTS_CUDA -def cuda_version() -> Optional[Tuple[int, int]]: +def cuda_version() -> tuple[int, int] | None: """Return the version of CUDA that fairseq2n supports. :returns: @@ -74,7 +73,7 @@ def cuda_version() -> Optional[Tuple[int, int]]: # Keeps the shared libraries that we load using our own extended lookup logic # in memory. -_libs: List[CDLL] = [] +_libs: list[CDLL] = [] def _load_shared_libraries() -> None: @@ -124,7 +123,7 @@ def _load_sndfile() -> None: _libs.append(libsndfile) -def _load_shared_library(lib_name: str) -> Optional[CDLL]: +def _load_shared_library(lib_name: str) -> CDLL | None: # In Conda environments, we always expect native libraries to be part of the # environment, so we skip the default lookup rules of the dynamic linker. if not "CONDA_PREFIX" in environ: diff --git a/native/python/src/fairseq2n/bindings/data/data_pipeline.cc b/native/python/src/fairseq2n/bindings/data/data_pipeline.cc index b5531eddc..deb6f4891 100644 --- a/native/python/src/fairseq2n/bindings/data/data_pipeline.cc +++ b/native/python/src/fairseq2n/bindings/data/data_pipeline.cc @@ -484,6 +484,7 @@ def_data_pipeline(py::module_ &data_module) data_pipeline_builder &self, float64 threshold, cost_fn fn, + std::optional maybe_bucket_fn, std::optional maybe_min_num_examples, std::optional maybe_max_num_examples, bool drop_remainder) -> data_pipeline_builder & @@ -491,6 +492,7 @@ def_data_pipeline(py::module_ &data_module) self = std::move(self).dynamic_bucket( threshold, std::move(fn), + std::move(maybe_bucket_fn), maybe_min_num_examples, maybe_max_num_examples, drop_remainder); @@ -499,6 +501,7 @@ def_data_pipeline(py::module_ &data_module) }, py::arg("threshold"), py::arg("fn"), + py::arg("bucket_creation_fn") = std::nullopt, py::arg("min_num_examples") = std::nullopt, py::arg("max_num_examples") = std::nullopt, py::arg("drop_remainder") = false) diff --git a/native/python/src/fairseq2n/bindings/data/text/converters.cc b/native/python/src/fairseq2n/bindings/data/text/converters.cc index 365296aaf..179862d20 100644 --- a/native/python/src/fairseq2n/bindings/data/text/converters.cc +++ b/native/python/src/fairseq2n/bindings/data/text/converters.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -27,6 +28,29 @@ using namespace fairseq2n::detail; namespace fairseq2n { +static std::shared_ptr +make_string_splitter( + std::string_view sep, + std::optional> &&maybe_names, + std::variant> &&indices, + bool exclude) +{ + if (sep.size() != 1) + throw_( + "`sep` must be of length 1, but is of length {} instead.", sep.size()); + + std::vector names{}; + if (maybe_names) + names = *std::move(maybe_names); + + return std::visit( + [&](auto &&idx) { + return std::make_shared( + sep[0], std::move(names), std::forward(idx), exclude); + }, + std::move(indices)); +} + void def_text_converters(py::module_ &text_module) { @@ -41,25 +65,29 @@ def_text_converters(py::module_ &text_module) std::optional> maybe_indices, bool exclude) { - if (sep.size() != 1) - throw_( - "`sep` must be of length 1, but is of length {} instead.", sep.size()); - - std::vector names{}; - if (maybe_names) - names = *std::move(maybe_names); - std::vector indices{}; if (maybe_indices) indices = *std::move(maybe_indices); - return std::make_shared( - sep[0], std::move(names), std::move(indices), exclude); + return make_string_splitter(sep, std::move(maybe_names), std::move(indices), exclude); }), py::arg("sep") = '\t', py::arg("names") = std::nullopt, py::arg("indices") = std::nullopt, py::arg("exclude") = false) + .def( + py::init([]( + std::string_view sep, + std::optional> maybe_names, + std::size_t index, + bool exclude) + { + return make_string_splitter(sep, std::move(maybe_names), index, exclude); + }), + py::arg("sep") = '\t', + py::arg("names") = std::nullopt, + py::arg("indices") = 0, + py::arg("exclude") = false) .def("__call__", &string_splitter::operator(), py::call_guard{}); // StrToIntConverter diff --git a/native/src/fairseq2n/data/data_pipeline.cc b/native/src/fairseq2n/data/data_pipeline.cc index a6debd39f..ee6b4edd7 100644 --- a/native/src/fairseq2n/data/data_pipeline.cc +++ b/native/src/fairseq2n/data/data_pipeline.cc @@ -407,6 +407,7 @@ data_pipeline_builder data_pipeline_builder::dynamic_bucket( float64 threshold, cost_fn fn, + std::optional maybe_bucket_fn, std::optional maybe_min_num_examples, std::optional maybe_max_num_examples, bool drop_remainder) && @@ -421,12 +422,17 @@ data_pipeline_builder::dynamic_bucket( throw_("`max_num_examples` must be greater than or equal to `min_num_examples`."); } - factory_ = [=, fn = std::move(fn), inner = std::move(factory_)]() mutable + factory_ = [ + =, + fn = std::move(fn), + maybe_bucket_fn = std::move(maybe_bucket_fn), + inner = std::move(factory_)]() mutable { return std::make_unique( inner(), threshold, std::move(fn), + std::move(maybe_bucket_fn), maybe_min_num_examples, maybe_max_num_examples, drop_remainder); diff --git a/native/src/fairseq2n/data/data_pipeline.h b/native/src/fairseq2n/data/data_pipeline.h index da296cec3..849a2439c 100644 --- a/native/src/fairseq2n/data/data_pipeline.h +++ b/native/src/fairseq2n/data/data_pipeline.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -111,14 +112,16 @@ class FAIRSEQ2_API data_pipeline { mutable bool is_broken_ = false; }; +using bucket_creation_fn = std::function, data_list>(data_list &&)>; + +using cost_fn = std::function; + using data_length_fn = std::function; using map_fn = std::function; using predicate_fn = std::function; -using cost_fn = std::function; - using yield_fn = std::function; class FAIRSEQ2_API data_pipeline_builder { @@ -152,6 +155,7 @@ class FAIRSEQ2_API data_pipeline_builder { dynamic_bucket( float64 threshold, cost_fn fn, + std::optional maybe_bucket_fn = std::nullopt, std::optional maybe_min_num_examples = std::nullopt, std::optional maybe_max_num_examples = std::nullopt, bool drop_remainder = false) &&; diff --git a/native/src/fairseq2n/data/dynamic_bucket_data_source.cc b/native/src/fairseq2n/data/dynamic_bucket_data_source.cc index e00dc09de..13b2eff1a 100644 --- a/native/src/fairseq2n/data/dynamic_bucket_data_source.cc +++ b/native/src/fairseq2n/data/dynamic_bucket_data_source.cc @@ -14,12 +14,14 @@ dynamic_bucket_data_source::dynamic_bucket_data_source( std::unique_ptr &&inner, float64 threshold, cost_fn &&fn, + std::optional &&maybe_bucket_fn, std::optional maybe_min_num_examples, std::optional maybe_max_num_examples, bool drop_remainder) noexcept : inner_{std::move(inner)}, threshold_{threshold}, cost_fn_{std::move(fn)}, + maybe_bucket_creation_fn_{std::move(maybe_bucket_fn)}, maybe_min_num_examples_{maybe_min_num_examples}, maybe_max_num_examples_{maybe_max_num_examples}, drop_remainder_{drop_remainder} @@ -28,10 +30,14 @@ dynamic_bucket_data_source::dynamic_bucket_data_source( std::optional dynamic_bucket_data_source::next() { - data_list output{}; + if (!return_buffer_.empty()) { + data output{return_buffer_.front()}; + return_buffer_.pop_front(); + return output; + } if (maybe_min_num_examples_) - output.reserve(*maybe_min_num_examples_); + buffer_.reserve(*maybe_min_num_examples_); float64 cost = 0; @@ -40,13 +46,13 @@ dynamic_bucket_data_source::next() bool minimum_size_met = true; if (maybe_min_num_examples_) - minimum_size_met = output.size() >= *maybe_min_num_examples_; + minimum_size_met = buffer_.size() >= *maybe_min_num_examples_; if (cost_threshold_met && minimum_size_met) return true; bool maximum_size_met = false; if (maybe_max_num_examples_) - maximum_size_met = output.size() >= *maybe_max_num_examples_; + maximum_size_met = buffer_.size() >= *maybe_max_num_examples_; return maximum_size_met; }; @@ -56,33 +62,58 @@ dynamic_bucket_data_source::next() if (!maybe_example) break; cost += cost_fn_(*maybe_example); - output.push_back(*std::move(maybe_example)); + buffer_.push_back(*std::move(maybe_example)); } - if (output.empty()) + if (buffer_.empty()) return std::nullopt; - if (drop_remainder_ && !bucket_ready()) + if (bucket_ready()) { + if (maybe_bucket_creation_fn_) { + const bucket_creation_fn& fn = *maybe_bucket_creation_fn_; + auto&& [return_buffer, new_buffer] = fn(std::move(buffer_)); + + buffer_ = std::move(new_buffer); + + data output{return_buffer.front()}; + return_buffer.pop_front(); + + return_buffer_ = std::move(return_buffer); + + return output; + } + } else if (drop_remainder_) { + buffer_.clear(); return std::nullopt; + } + data_list output = std::move(buffer_); + buffer_.clear(); return output; } void dynamic_bucket_data_source::reset(bool reset_rng) { + buffer_.clear(); inner_->reset(reset_rng); } void dynamic_bucket_data_source::record_position(tape &t, bool strict) const { + if (maybe_bucket_creation_fn_) { + t.record(buffer_); + } inner_->record_position(t, strict); } void dynamic_bucket_data_source::reload_position(tape &t, bool strict) { + if (maybe_bucket_creation_fn_) { + buffer_ = t.read(); + } inner_->reload_position(t, strict); } diff --git a/native/src/fairseq2n/data/dynamic_bucket_data_source.h b/native/src/fairseq2n/data/dynamic_bucket_data_source.h index 31c4624b2..341e598ec 100644 --- a/native/src/fairseq2n/data/dynamic_bucket_data_source.h +++ b/native/src/fairseq2n/data/dynamic_bucket_data_source.h @@ -22,6 +22,7 @@ class dynamic_bucket_data_source final : public data_source { std::unique_ptr &&inner, float64 threshold, cost_fn &&fn, + std::optional &&maybe_bucket_fn, std::optional maybe_min_num_examples, std::optional maybe_max_num_examples, bool drop_remainder) noexcept; @@ -45,9 +46,14 @@ class dynamic_bucket_data_source final : public data_source { std::unique_ptr inner_; float64 threshold_; cost_fn cost_fn_; + std::optional maybe_bucket_creation_fn_; std::optional maybe_min_num_examples_; std::optional maybe_max_num_examples_; bool drop_remainder_; + + data_list buffer_{}; + std::deque return_buffer_{}; + }; } // namespace fairseq2n::detail diff --git a/native/src/fairseq2n/data/text/string_splitter.cc b/native/src/fairseq2n/data/text/string_splitter.cc index 15ed94094..542bdd8dc 100644 --- a/native/src/fairseq2n/data/text/string_splitter.cc +++ b/native/src/fairseq2n/data/text/string_splitter.cc @@ -22,12 +22,35 @@ string_splitter::string_splitter( std::vector names, std::vector indices, bool exclude) - : separator_{separator}, names_(std::move(names)), indices_{std::move(indices)}, exclude_{exclude} + : separator_{separator}, + names_{std::move(names)}, + indices_{std::move(indices)}, + exclude_{exclude}, + single_column_{false} { + finalize_indices(); +} + +string_splitter::string_splitter( + char separator, + std::vector names, + std::size_t index, + bool exclude) + : separator_{separator}, + names_{std::move(names)}, + indices_{index}, + exclude_{exclude}, + single_column_{true} +{ + finalize_indices(); +} + +void +string_splitter::finalize_indices() { if (indices_.empty()) return; - if (!names_.empty() && !exclude && names_.size() != indices_.size()) + if (!names_.empty() && !exclude_ && names_.size() != indices_.size()) throw_( "`names` and `indices` must have the same length, but have the lengths {} and {} instead.", names_.size(), indices_.size()); @@ -76,9 +99,13 @@ string_splitter::operator()(data &&d) const throw_( "The input string must have at least {} field(s), but has {} instead.", indices_.back(), idx); - // If no names specified, return as list. - if (names_.empty()) + // If no names specified, return a list, or a string if a single non-excluding index is specified. + if (names_.empty()) { + if (single_column_ && !exclude_) + return data{std::move(fields[0])}; + return fields; + } // Otherwise, as dictionary. if (names_.size() != fields.size()) diff --git a/native/src/fairseq2n/data/text/string_splitter.h b/native/src/fairseq2n/data/text/string_splitter.h index 0072dd6ff..dba38f247 100644 --- a/native/src/fairseq2n/data/text/string_splitter.h +++ b/native/src/fairseq2n/data/text/string_splitter.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "fairseq2n/api.h" @@ -25,6 +26,13 @@ class FAIRSEQ2_API string_splitter final { std::vector indices = {}, bool exclude = false); + explicit + string_splitter( + char separator = '\t', + std::vector names = {}, + std::size_t index = 0, + bool exclude = false); + data operator()(data &&d) const; @@ -33,6 +41,9 @@ class FAIRSEQ2_API string_splitter final { std::vector names_; std::vector indices_; bool exclude_; + bool single_column_; + + void finalize_indices(); }; } // namespace fairseq2n diff --git a/pyproject.toml b/pyproject.toml index d4b7cc881..a119bed9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ disable_error_code = "type-abstract,typeddict-unknown-key" disallow_untyped_calls = false disallow_untyped_decorators = false files = "setup.py,src,tests" -python_version = 3.8 +python_version = "3.10" show_error_codes = true show_error_context = true strict = true @@ -30,7 +30,7 @@ warn_unused_configs = false warn_unused_ignores = false [[tool.mypy.overrides]] -module = "torch.distributed.*" +module = "torch.distributed.*,torch.optim.*" implicit_reexport = true # TODO: fix! diff --git a/setup.py b/setup.py index b4f9012f2..50ab008a8 100644 --- a/setup.py +++ b/setup.py @@ -8,14 +8,17 @@ from setuptools import find_namespace_packages, setup -version = "0.3.0.dev0" +version = "0.4.0.dev0" # If this is a local development install, allow nightly fairseq2n builds to # take precedence. if version.endswith(".dev0"): fairseq2n_version_spec = f">={version},<={version[:-5]}" else: - fairseq2n_version_spec = f"=={version}" + p = version.split("+", maxsplit=1) + + fairseq2n_version_spec = "==" + p[0] + setup( name="fairseq2", @@ -33,10 +36,9 @@ "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], package_dir={"": "src"}, @@ -46,15 +48,15 @@ "fairseq2.assets.cards": ["**/*.yaml"], }, zip_safe=False, - python_requires=">=3.8", + python_requires=">=3.10", install_requires=[ - "blobfile~=2.1", "editdistance~=0.8", "fairseq2n" + fairseq2n_version_spec, "importlib_metadata~=7.0", "importlib_resources~=6.4", + "mypy-extensions~=1.0", "numpy~=1.23", - "packaging~=23.1", + "packaging~=24.1", "psutil~=5.9", "pyyaml~=6.0", "rich~=13.7", @@ -62,7 +64,11 @@ "tiktoken~=0.7", "torcheval~=0.0.6", "tqdm~=4.62", - "typing_extensions~=4.3;python_version<'3.10'", + "typing_extensions~=4.12", + # This dependency is required for tiktoken.load.read_file, but it's + # listed as optional in tiktoken's pyproject.toml + # (https://github.com/openai/tiktoken/blob/main/pyproject.toml#L9) + "blobfile~=3.0.0", ], extras_require={ "arrow": ["pyarrow>=13.0.0", "pandas~=2.0.0"], diff --git a/src/fairseq2/__init__.py b/src/fairseq2/__init__.py index 4a49f4015..2c2804273 100644 --- a/src/fairseq2/__init__.py +++ b/src/fairseq2/__init__.py @@ -6,53 +6,16 @@ from __future__ import annotations -__version__ = "0.3.0.dev0" +__version__ = "0.4.0.dev0" import fairseq2n # Report any fairseq2n initialization error eagerly. # isort: split -import fairseq2.datasets import fairseq2.models # isort: split -import os +from fairseq2.setup import setup_fairseq2 as setup_fairseq2 -from importlib_metadata import entry_points - -from fairseq2.logging import get_log_writer - -log = get_log_writer(__name__) - -_setup_complete = False - - -def setup_extensions() -> None: - global _setup_complete - - if _setup_complete: - return - - # Mark as complete early on to avoid recursive calls. - _setup_complete = True - - for entry_point in entry_points(group="fairseq2"): - try: - setup_extension = entry_point.load() - - setup_extension() - except TypeError: - raise RuntimeError( - f"The entry point '{entry_point.value}' is not a valid fairseq2 setup function." - ) from None - except Exception as ex: - if "FAIRSEQ2_EXTENSION_TRACE" in os.environ: - raise RuntimeError( - f"The setup function at '{entry_point.value}' has failed. See nested exception for details." - ) from ex - - log.warning( - "The setup function at '{}' has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", - entry_point.value, - ) +setup_extensions = setup_fairseq2 # compat diff --git a/src/fairseq2/assets/__init__.py b/src/fairseq2/assets/__init__.py index 087f5db45..b13739375 100644 --- a/src/fairseq2/assets/__init__.py +++ b/src/fairseq2/assets/__init__.py @@ -7,10 +7,6 @@ from __future__ import annotations from fairseq2.assets.card import AssetCard as AssetCard -from fairseq2.assets.card import AssetCardError as AssetCardError -from fairseq2.assets.card import ( - AssetCardFieldNotFoundError as AssetCardFieldNotFoundError, -) from fairseq2.assets.download_manager import AssetDownloadError as AssetDownloadError from fairseq2.assets.download_manager import ( AssetDownloadManager as AssetDownloadManager, @@ -18,18 +14,23 @@ from fairseq2.assets.download_manager import ( InProcAssetDownloadManager as InProcAssetDownloadManager, ) -from fairseq2.assets.download_manager import ( - default_download_manager as default_download_manager, +from fairseq2.assets.error import AssetCardError as AssetCardError +from fairseq2.assets.error import ( + AssetCardFieldNotFoundError as AssetCardFieldNotFoundError, ) +from fairseq2.assets.error import AssetCardNotFoundError as AssetCardNotFoundError from fairseq2.assets.error import AssetError as AssetError +from fairseq2.assets.error import AssetNotFoundError as AssetNotFoundError from fairseq2.assets.metadata_provider import ( AbstractAssetMetadataProvider as AbstractAssetMetadataProvider, ) from fairseq2.assets.metadata_provider import AssetMetadataError as AssetMetadataError +from fairseq2.assets.metadata_provider import ( + AssetMetadataNotFoundError as AssetMetadataNotFoundError, +) from fairseq2.assets.metadata_provider import ( AssetMetadataProvider as AssetMetadataProvider, ) -from fairseq2.assets.metadata_provider import AssetNotFoundError as AssetNotFoundError from fairseq2.assets.metadata_provider import ( FileAssetMetadataProvider as FileAssetMetadataProvider, ) @@ -39,8 +40,14 @@ from fairseq2.assets.metadata_provider import ( PackageAssetMetadataProvider as PackageAssetMetadataProvider, ) +from fairseq2.assets.metadata_provider import PackageFileLister as PackageFileLister +from fairseq2.assets.metadata_provider import ( + WheelPackageFileLister as WheelPackageFileLister, +) from fairseq2.assets.metadata_provider import load_metadata_file as load_metadata_file from fairseq2.assets.store import AssetStore as AssetStore from fairseq2.assets.store import EnvironmentResolver as EnvironmentResolver from fairseq2.assets.store import StandardAssetStore as StandardAssetStore from fairseq2.assets.store import default_asset_store as default_asset_store +from fairseq2.assets.store import get_asset_dir as get_asset_dir +from fairseq2.assets.store import get_user_asset_dir as get_user_asset_dir diff --git a/src/fairseq2/assets/card.py b/src/fairseq2/assets/card.py index 75f0379a1..388c8b8a2 100644 --- a/src/fairseq2/assets/card.py +++ b/src/fairseq2/assets/card.py @@ -8,25 +8,18 @@ import os import re +from collections.abc import Mapping, MutableMapping, Set, Sized from pathlib import Path -from typing import ( - AbstractSet, - Any, - Dict, - Final, - List, - Mapping, - MutableMapping, - Optional, - TypeVar, - final, -) +from typing import Any, Final, cast, final from urllib.parse import urlparse, urlunparse -from fairseq2.assets.error import AssetError -from fairseq2.utils.value_converter import ValueConverter, default_value_converter - -T = TypeVar("T") +from fairseq2.assets.error import AssetCardError, AssetCardFieldNotFoundError +from fairseq2.error import InternalError +from fairseq2.utils.structured import ( + StructureError, + default_value_converter, + unstructure, +) @final @@ -34,16 +27,16 @@ class AssetCard: """Holds information about an asset.""" _name: str - _metadata: MutableMapping[str, Any] - _base: Optional[AssetCard] - _value_converter: ValueConverter + _metadata: MutableMapping[str, object] + _base_card: AssetCard | None + _base_path: Path | None def __init__( self, - metadata: MutableMapping[str, Any], - base: Optional[AssetCard] = None, - *, - value_converter: Optional[ValueConverter] = None, + name: str, + metadata: MutableMapping[str, object], + base_card: AssetCard | None = None, + base_path: Path | None = None, ) -> None: """ :param metadata: @@ -51,31 +44,16 @@ def __init__( contain a specific piece of information about the asset. :param base: The card that this card derives from. - :param value_converter: - The :class:`ValueConverter` instance to use. If ``None``, the - default instance will be used. """ - try: - name = metadata["name"] - except KeyError: - raise AssetCardError( - "`metadata` must contain a key named 'name'." - ) from None - - if not isinstance(name, str): - raise AssetCardError( - f"The value of 'name' in `metadata` must be of type `{str}`, but is of type `{type(name)}` instead." - ) - self._name = name self._metadata = metadata - self._base = base - self._value_converter = value_converter or default_value_converter + self._base_card = base_card + self._base_path = base_path def field(self, name: str) -> AssetCardField: """Return a field of this card. - If the card does not contain the specified field, its base card will be + If the card does not contain the specified field, its base cards will be checked recursively. :param name: @@ -83,10 +61,11 @@ def field(self, name: str) -> AssetCardField: """ return AssetCardField(self, path=[name]) - def _get_field_value(self, name: str, path: List[str]) -> Any: - assert len(path) > 0 + def _get_field_value(self, leaf_card: AssetCard, path: list[str]) -> object: + if len(path) == 0: + raise InternalError("`path` has zero length.") - metadata = self._metadata + metadata: object = self._metadata contains = True @@ -96,11 +75,11 @@ def _get_field_value(self, name: str, path: List[str]) -> Any: break - if not isinstance(metadata, Mapping): + if not isinstance(metadata, MutableMapping): pathname = ".".join(path) raise AssetCardFieldNotFoundError( - f"The asset card '{name}' must have a field named '{pathname}'." + leaf_card.name, f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." # fmt: skip ) try: @@ -111,43 +90,63 @@ def _get_field_value(self, name: str, path: List[str]) -> Any: break if not contains: - if self._base is not None: - return self._base._get_field_value(name, path) + if self._base_card is not None: + return self._base_card._get_field_value(leaf_card, path) pathname = ".".join(path) raise AssetCardFieldNotFoundError( - f"The asset card '{name}' must have a field named '{pathname}'." + leaf_card.name, f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." # fmt: skip ) return metadata - def _set_field_value(self, path: List[str], value: Any) -> None: - assert len(path) > 0 + def _set_field_value(self, path: list[str], value: object) -> None: + if len(path) == 0: + raise InternalError("`path` has zero length.") metadata = self._metadata for depth, field in enumerate(path[:-1]): - try: - metadata = metadata[field] - except KeyError: - tmp: Dict[str, Any] = {} + value_ = metadata.get(field) + if value_ is None: + tmp: dict[str, object] = {} metadata[field] = tmp - metadata = tmp + value_ = tmp - if not isinstance(metadata, Mapping): + if not isinstance(value_, MutableMapping): conflict_pathname = ".".join(path[: depth + 1]) pathname = ".".join(path) raise AssetCardError( - f"The asset card '{self._name}' cannot have a field named '{pathname}' due to path conflict at '{conflict_pathname}'." + self._name, f"The '{self._name}' asset card cannot have a field named '{pathname}' due to path conflict at '{conflict_pathname}'." # fmt: skip ) + metadata = value_ + metadata[path[-1]] = value + def flatten(self) -> AssetCard: + """ + Flattens the metadata of this card and all its bases into a single one. + """ + all_metadata = [] + + card: AssetCard | None = self + + while card is not None: + all_metadata.append(card._metadata) + + card = card._base_card + + for metadata in all_metadata[-2::-1]: + all_metadata[-1].update(metadata) + + return AssetCard(self._name, all_metadata[-1]) + def __repr__(self) -> str: return repr(self._metadata) @@ -157,14 +156,14 @@ def name(self) -> str: return self._name @property - def metadata(self) -> Mapping[str, Any]: + def metadata(self) -> Mapping[str, object]: """The metadata of the asset.""" return self._metadata @property - def base(self) -> Optional[AssetCard]: + def base(self) -> AssetCard | None: """The card that this card derives from.""" - return self._base + return self._base_card @final @@ -172,9 +171,9 @@ class AssetCardField: """Represents a field of an asset card.""" _card: AssetCard - _path: List[str] + _path: list[str] - def __init__(self, card: AssetCard, path: List[str]) -> None: + def __init__(self, card: AssetCard, path: list[str]) -> None: """ :param card: The card owning this field. @@ -195,48 +194,48 @@ def field(self, name: str) -> AssetCardField: def exists(self) -> bool: """Return ``True`` if the field exists.""" try: - self._card._get_field_value(self._card.name, self._path) - - return True + self._card._get_field_value(self._card, self._path) except AssetCardFieldNotFoundError: return False - def as_(self, type_hint: Any, *, allow_empty: bool = False) -> Any: + return True + + def as_unstructured(self) -> object: + """Return the value of this field in unstructured form.""" + return self._card._get_field_value(self._card, self._path) + + def as_(self, type_: object, *, allow_empty: bool = False) -> Any: """Return the value of this field. - :param type_hint: - The type hint of the field. + :param type_: + The type expression of the field. :param allow_empty: If ``True``, allows the field to be empty. """ - unstructured_value = self._card._get_field_value(self._card.name, self._path) + unstructured_value = self._card._get_field_value(self._card, self._path) try: - value = self._card._value_converter.structure(unstructured_value, type_hint) - except ValueError as ex: - raise ValueError( - "`type_hint` must be a supported type annotation. See nested exception for details." - ) from ex - except TypeError as ex: + value = default_value_converter.structure(unstructured_value, type_) + except StructureError as ex: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the field '{pathname}' of the asset card '{self._card.name}' cannot be retrieved as `{type_hint}`. See nested exception for details." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card cannot be parsed as `{type_}`. See the nested exception for details." # fmt: skip ) from ex if value is None: return value - if not allow_empty and not value: + if not allow_empty and isinstance(value, Sized) and len(value) == 0: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the field '{pathname}' of the asset card '{self._card.name}' must not be empty." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is empty." # fmt: skip ) return value - def as_one_of(self, valid_values: AbstractSet[str]) -> str: + def as_one_of(self, valid_values: Set[str]) -> str: """Return the value of this field as one of the values in ``valid_values`` :param values: @@ -245,7 +244,7 @@ def as_one_of(self, valid_values: AbstractSet[str]) -> str: if not valid_values: raise ValueError("`valid_values` must not be empty.") - value = self.as_(str) + value = cast(str, self.as_(str)) if value not in valid_values: pathname = ".".join(self._path) @@ -254,82 +253,76 @@ def as_one_of(self, valid_values: AbstractSet[str]) -> str: values.sort() + s = ", ".join(values) + raise AssetCardError( - f"The value of the field '{pathname}' of the asset card '{self._card.name}' must be one of {repr(values)}, but is {repr(value)} instead." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be one of the following values, but is '{value}' instead: {s}" # fmt: skip ) - return value # type: ignore[no-any-return] + return value def as_uri(self) -> str: """Return the value of this field as a URI.""" - value = self.as_(str) + value = cast(str, self.as_(str)) try: if not _starts_with_scheme(value): path = Path(value) - if not path.is_absolute(): - base_path = self._card.metadata.get("__base_path__") - if base_path is not None: - path = base_path.joinpath(path) + if not path.is_absolute() and self._card._base_path is not None: + path = self._card._base_path.joinpath(path) return path.as_uri() - return urlunparse(urlparse(value)) # type: ignore[no-any-return] - except ValueError as ex: + return urlunparse(urlparse(value)) + except ValueError: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the field '{pathname}' of the asset card '{self._card.name}' must be a URI or an absolute pathname, but is '{value}' instead." - ) from ex + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a URI or an absolute pathname, but is '{value}' instead." # fmt: skip + ) from None def as_filename(self) -> str: """Return the value of this field as a filename.""" - value = self.as_(str) + value = cast(str, self.as_(str)) if os.sep in value or (os.altsep and os.altsep in value): pathname = ".".join(self._path) raise AssetCardError( - f"The value of the field '{pathname}' of the asset card '{self._card.name}' must be a filename, but is '{value}' instead." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a filename, but is '{value}' instead." # fmt: skip ) - return value # type: ignore[no-any-return] + return value def get_as_( - self, type_hint: Any, default: Optional[T] = None, *, allow_empty: bool = False + self, type_: object, default: object = None, *, allow_empty: bool = False ) -> Any: """Return the value of this field if it exists; otherwise, return ``default``. + :param type_: + The type expression of the field. :param default: The default value. :param allow_empty: If ``True``, allows the field to be empty. """ try: - return self.as_(type_hint, allow_empty=True) + return self.as_(type_, allow_empty=allow_empty) except AssetCardFieldNotFoundError: return default - def set(self, value: Any) -> None: + def set(self, value: object) -> None: """Set the value of this field.""" try: - unstructured_value = self._card._value_converter.unstructure(value) - except TypeError as ex: - raise TypeError( - "`value` must be of a supported type. See nested exception for details." + unstructured_value = unstructure(value) + except StructureError as ex: + raise ValueError( + "`value` must be of a type that can be unstructured. See the nested exception for details." ) from ex self._card._set_field_value(self._path, unstructured_value) -class AssetCardError(AssetError): - """Raised when an asset card operation fails.""" - - -class AssetCardFieldNotFoundError(AssetCardError): - """Raised when an asset card field cannot be found.""" - - _SCHEME_REGEX: Final = re.compile("^[a-zA-Z0-9]+://") diff --git a/src/fairseq2/assets/cards/datasets/librispeech.yaml b/src/fairseq2/assets/cards/datasets/librispeech.yaml index 509f128ad..e38c6bb22 100644 --- a/src/fairseq2/assets/cards/datasets/librispeech.yaml +++ b/src/fairseq2/assets/cards/datasets/librispeech.yaml @@ -7,9 +7,14 @@ name: librispeech_asr dataset_family: generic_asr tokenizer: "https://dl.fbaipublicfiles.com/fairseq/wav2vec/librispeech_asr.model" -tokenizer_family: librispeech_asr +tokenizer_family: char_tokenizer --- name: librispeech_asr_100h base: librispeech_asr + +--- + +name: librispeech_960h +dataset_family: generic_speech diff --git a/src/fairseq2/assets/cards/models/jepa.yaml b/src/fairseq2/assets/cards/models/jepa.yaml new file mode 100644 index 000000000..f067f9543 --- /dev/null +++ b/src/fairseq2/assets/cards/models/jepa.yaml @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +name: jepa_vitl16 +model_family: jepa +model_arch: large +model_config: + encoder_config: + input_dims: [16, 224, 224] + patch_dims: [2, 16, 16] + uniform_power: true +checkpoint: "https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar" + +--- + +name: jepa_vith16 +model_family: jepa +model_arch: huge +model_config: + encoder_config: + input_dims: [16, 224, 224] + patch_dims: [2, 16, 16] + uniform_power: true +checkpoint: "https://dl.fbaipublicfiles.com/jepa/vith16/vith16.pth.tar" + +--- + +name: jepa_vith16_384 +model_family: jepa +model_arch: huge +model_config: + encoder_config: + input_dims: [16, 384, 384] + patch_dims: [2, 16, 16] + uniform_power: true +checkpoint: "https://dl.fbaipublicfiles.com/jepa/vith16-384/vith16-384.pth.tar" diff --git a/src/fairseq2/assets/cards/models/llama.yaml b/src/fairseq2/assets/cards/models/llama.yaml index 93abeb080..8d739c69e 100644 --- a/src/fairseq2/assets/cards/models/llama.yaml +++ b/src/fairseq2/assets/cards/models/llama.yaml @@ -97,7 +97,6 @@ name: llama3_70b base: llama3 model_arch: llama3_70b num_shards: 8 -shard_embed_dim: false --- @@ -105,7 +104,6 @@ name: llama3_70b_instruct base: llama3_instruct model_arch: llama3_70b num_shards: 8 -shard_embed_dim: false --- @@ -125,7 +123,13 @@ name: llama3_1_70b base: llama3 model_arch: llama3_1_70b num_shards: 8 -shard_embed_dim: false + +--- + +name: llama3_3_70b_instruct +base: llama3_instruct +model_arch: llama3_1_70b +num_shards: 8 --- @@ -133,4 +137,27 @@ name: llama3_1_70b_instruct base: llama3_instruct model_arch: llama3_1_70b num_shards: 8 -shard_embed_dim: false + +--- + +name: llama3_2_1b +base: llama3 +model_arch: llama3_2_1b + +--- + +name: llama3_2_1b_instruct +base: llama3_instruct +model_arch: llama3_2_1b + +--- + +name: llama3_2_3b +base: llama3 +model_arch: llama3_2_3b + +--- + +name: llama3_2_3b_instruct +base: llama3_instruct +model_arch: llama3_2_3b \ No newline at end of file diff --git a/src/fairseq2/assets/download_manager.py b/src/fairseq2/assets/download_manager.py index ceb780b01..073841cf3 100644 --- a/src/fairseq2/assets/download_manager.py +++ b/src/fairseq2/assets/download_manager.py @@ -8,28 +8,26 @@ import os from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import ExitStack from hashlib import sha1 from pathlib import Path from shutil import rmtree from tarfile import TarFile, is_tarfile from tempfile import NamedTemporaryFile -from typing import Dict, Iterator, Optional, final +from typing import final from urllib.error import HTTPError, URLError from urllib.parse import unquote, urlparse from urllib.request import Request, urlopen from zipfile import BadZipFile, ZipFile from tqdm import tqdm # type: ignore[import] +from typing_extensions import override from fairseq2.assets.card import _starts_with_scheme -from fairseq2.assets.error import AssetError -from fairseq2.logging import get_log_writer -from fairseq2.typing import override +from fairseq2.logging import log from fairseq2.utils.env import get_path_from_env -log = get_log_writer(__name__) - class AssetDownloadManager(ABC): """Downloads assets.""" @@ -40,7 +38,7 @@ def download_checkpoint( uri: str, model_name: str, *, - shard_idx: Optional[int] = None, + shard_idx: int | None = None, force: bool = False, progress: bool = True, ) -> Path: @@ -67,7 +65,7 @@ def download_tokenizer( uri: str, model_name: str, *, - tokenizer_name: Optional[str] = None, + tokenizer_name: str | None = None, force: bool = False, progress: bool = True, ) -> Path: @@ -113,6 +111,10 @@ def download_dataset( """ +class AssetDownloadError(Exception): + pass + + @final class InProcAssetDownloadManager(AssetDownloadManager): """Downloads assets in this process.""" @@ -120,9 +122,9 @@ class InProcAssetDownloadManager(AssetDownloadManager): _cache_dir: Path def __init__(self) -> None: - cache_dir = get_path_from_env("FAIRSEQ2_CACHE_DIR", log, missing_ok=True) + cache_dir = get_path_from_env("FAIRSEQ2_CACHE_DIR", missing_ok=True) if cache_dir is None: - cache_dir = get_path_from_env("XDG_CACHE_HOME", log) + cache_dir = get_path_from_env("XDG_CACHE_HOME") if cache_dir is None: cache_dir = Path("~/.cache").expanduser() @@ -136,7 +138,7 @@ def download_checkpoint( uri: str, model_name: str, *, - shard_idx: Optional[int] = None, + shard_idx: int | None = None, force: bool = False, progress: bool = True, ) -> Path: @@ -157,7 +159,7 @@ def download_tokenizer( uri: str, model_name: str, *, - tokenizer_name: Optional[str] = None, + tokenizer_name: str | None = None, force: bool = False, progress: bool = True, ) -> Path: @@ -189,12 +191,12 @@ def download_dataset( class _AssetDownloadOp: _cache_dir: Path _uri: str - _uri_params: Dict[str, str] - _asset_dir: Optional[Path] + _uri_params: dict[str, str] + _asset_dir: Path | None _display_name: str _force: bool _progress: bool - _shard_idx: Optional[int] + _shard_idx: int | None def __init__( self, @@ -203,7 +205,7 @@ def __init__( display_name: str, force: bool, progress: bool, - shard_idx: Optional[int] = None, + shard_idx: int | None = None, ) -> None: self._cache_dir = cache_dir self._uri = uri @@ -223,7 +225,7 @@ def run(self) -> Path: if (asset_path := self._try_uri_as_path()) is not None: if not asset_path.exists(): - raise AssetError( + raise AssetDownloadError( f"The {self._display_name} cannot be found at {asset_path}." ) @@ -247,10 +249,10 @@ def _process_uri(self) -> None: uri = Path(uri).as_uri() # Normalize. parsed_uri = urlparse(uri) - except ValueError as ex: + except ValueError: raise ValueError( f"`uri` must be a URI or an absolute pathname, but is '{uri}' instead." - ) from ex + ) from None if parsed_uri.params: for param in parsed_uri.params.split(";"): @@ -278,7 +280,7 @@ def _format_uri_with_shard_index(self) -> None: sharded_uri = self._uri.replace("%7Bshard_idx%7D", str(self._shard_idx)) if sharded_uri == self._uri: - raise AssetError( + raise AssetDownloadError( f"`shard_idx` is specified, but the {self._display_name} is not sharded." ) @@ -286,11 +288,11 @@ def _format_uri_with_shard_index(self) -> None: def _check_if_gated_asset(self) -> None: if self._uri_params.get("gated", "false").strip().lower() == "true": - raise AssetError( + raise AssetDownloadError( f"The {self._display_name} is gated. Please visit {self._uri} to learn how to get access." ) - def _try_uri_as_path(self) -> Optional[Path]: + def _try_uri_as_path(self) -> Path | None: if self._uri.startswith("file://"): return Path(unquote(self._uri[7:])) @@ -316,7 +318,7 @@ def _prepare_op(self) -> None: rmtree(asset_dir) except OSError as ex: raise AssetDownloadError( - f"The asset cache directory of the {self._display_name} cannot be deleted. See nested exception for details." + f"The asset cache directory of the {self._display_name} cannot be deleted. See the nested exception for details." ) from ex download_dir = asset_dir.with_suffix(".download") @@ -325,7 +327,7 @@ def _prepare_op(self) -> None: rmtree(download_dir) except OSError as ex: raise AssetDownloadError( - f"The asset download directory of the {self._display_name} cannot be deleted. See nested exception for details." + f"The asset download directory of the {self._display_name} cannot be deleted. See the nested exception for details." ) from ex download_dir = asset_dir.with_suffix(".download.tmp") @@ -334,7 +336,7 @@ def _prepare_op(self) -> None: rmtree(download_dir) except OSError as ex: raise AssetDownloadError( - f"The asset download directory of the {self._display_name} cannot be deleted. See nested exception for details." + f"The asset download directory of the {self._display_name} cannot be deleted. See the nested exception for details." ) from ex else: if asset_dir.exists(): @@ -364,8 +366,8 @@ def _download_asset(self) -> None: try: tmp_dir.mkdir(parents=True, exist_ok=True) except OSError as ex: - raise AssetError( - f"The asset download directory of the {self._display_name} cannot be created. See nested exception for details." + raise AssetDownloadError( + f"The asset download directory of the {self._display_name} cannot be created. See the nested exception for details." ) from ex def remove_tmp_dir() -> None: @@ -392,7 +394,7 @@ def remove_tmp_dir() -> None: response = cleanup_stack.enter_context(urlopen(request)) except URLError as ex: raise AssetDownloadError( - f"The download of the {self._display_name} has failed. See nested exception for details." + f"The download of the {self._display_name} has failed. See the nested exception for details." ) from ex except HTTPError as ex: raise AssetDownloadError( @@ -464,15 +466,15 @@ def remove_tmp_dir() -> None: try: os.replace(fp.name, asset_file) except OSError: - raise AssetError( - f"The {self._display_name} cannot be saved to the asset download directory. See nested exception for details." + raise AssetDownloadError( + f"The {self._display_name} cannot be saved to the asset download directory. See the nested exception for details." ) try: tmp_dir.replace(download_dir) except OSError: - raise AssetError( - f"The asset download directory of the {self._display_name} cannot be renamed. See nested exception for details." + raise AssetDownloadError( + f"The asset download directory of the {self._display_name} cannot be renamed. See the nested exception for details." ) succeeded = True @@ -493,8 +495,8 @@ def _ensure_asset_extracted(self) -> None: try: asset_dir.mkdir(parents=True, exist_ok=True) except OSError as ex: - raise AssetError( - f"The asset cache directory of the {self._display_name} cannot be created. See nested exception for details." + raise AssetDownloadError( + f"The asset cache directory of the {self._display_name} cannot be created. See the nested exception for details." ) from ex def iter_dir() -> Iterator[Path]: @@ -502,8 +504,8 @@ def iter_dir() -> Iterator[Path]: for path in download_dir.iterdir(): yield path except OSError as ex: - raise AssetError( - f"The asset download directory of the {self._display_name} cannot be traversed. See nested exception for details." + raise AssetDownloadError( + f"The asset download directory of the {self._display_name} cannot be traversed. See the nested exception for details." ) from ex for asset_path in iter_dir(): @@ -517,8 +519,8 @@ def iter_dir() -> Iterator[Path]: with ZipFile(asset_path) as zip_fp: zip_fp.extractall(path=asset_dir) except (KeyError, OSError, BadZipFile) as ex: - raise AssetError( - f"The {self._display_name} cannot be extracted. See nested exception for details." + raise AssetDownloadError( + f"The {self._display_name} cannot be extracted. See the nested exception for details." ) from ex try: @@ -534,8 +536,8 @@ def iter_dir() -> Iterator[Path]: with TarFile(asset_path) as tar_fp: tar_fp.extractall(path=asset_dir) except (KeyError, OSError) as ex: - raise AssetError( - f"The {self._display_name} cannot be extracted. See nested exception for details." + raise AssetDownloadError( + f"The {self._display_name} cannot be extracted. See the nested exception for details." ) from ex try: @@ -548,15 +550,15 @@ def iter_dir() -> Iterator[Path]: try: asset_path.replace(asset_dir.joinpath(asset_path.name)) except OSError as ex: - raise AssetError( - f"The {self._display_name} cannot be moved to the asset cache directory. See nested exception for details." + raise AssetDownloadError( + f"The {self._display_name} cannot be moved to the asset cache directory. See the nested exception for details." ) from ex try: rmtree(download_dir) except OSError as ex: - raise AssetError( - f"The asset download directory of the {self._display_name} cannot be deleted. See nested exception for details." + raise AssetDownloadError( + f"The asset download directory of the {self._display_name} cannot be deleted. See the nested exception for details." ) from ex def _get_final_asset_path(self) -> Path: @@ -573,12 +575,12 @@ def _get_final_asset_path(self) -> Path: try: asset_path.relative_to(asset_dir) except ValueError as ex: - raise AssetError( + raise AssetDownloadError( f"The 'path' URI parameter of the {self._display_name} ({asset_pathname}) points to a path outside of the asset cache directory." ) from ex if not asset_path.exists(): - raise AssetError( + raise AssetDownloadError( f"The {self._display_name} cannot be found. Please set `force` to `True` and, if the problem persists, file a bug report." ) @@ -595,20 +597,13 @@ def _get_final_asset_path(self) -> Path: asset_path = path except OSError as ex: - raise AssetError( - f"The asset cache directory of the {self._display_name} cannot be traversed. See nested exception for details." + raise AssetDownloadError( + f"The asset cache directory of the {self._display_name} cannot be traversed. See the nested exception for details." ) from ex if asset_path is None: - raise AssetError( + raise AssetDownloadError( f"The asset cache directory of the {self._display_name} is empty. Please set `force` to `True` and, if the problem persists, file a bug report." ) return asset_path - - -class AssetDownloadError(AssetError): - """Raised when an asset download operation fails.""" - - -default_download_manager = InProcAssetDownloadManager() diff --git a/src/fairseq2/assets/error.py b/src/fairseq2/assets/error.py index f6bdc5853..ecebbd96a 100644 --- a/src/fairseq2/assets/error.py +++ b/src/fairseq2/assets/error.py @@ -6,6 +6,28 @@ from __future__ import annotations +from typing import TypeAlias -class AssetError(RuntimeError): - """Raised when an asset operation fails.""" + +class AssetError(Exception): + pass + + +class AssetCardError(AssetError): + name: str + + def __init__(self, name: str, message: str) -> None: + super().__init__(message) + + self.name = name + + +class AssetCardNotFoundError(AssetCardError): + pass + + +class AssetCardFieldNotFoundError(AssetCardError): + pass + + +AssetNotFoundError: TypeAlias = AssetCardNotFoundError # compat diff --git a/src/fairseq2/assets/metadata_provider.py b/src/fairseq2/assets/metadata_provider.py index 8a93a5398..7cab9eaf1 100644 --- a/src/fairseq2/assets/metadata_provider.py +++ b/src/fairseq2/assets/metadata_provider.py @@ -6,27 +6,26 @@ from __future__ import annotations -import os from abc import ABC, abstractmethod +from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, final +from typing import NoReturn, final -import yaml -from importlib_resources import files +from importlib_resources import files as get_files from importlib_resources.readers import MultiplexedPath -from typing_extensions import NoReturn -from yaml import YAMLError +from typing_extensions import override -from fairseq2.assets.error import AssetError -from fairseq2.typing import override +from fairseq2.error import ContractError, InternalError +from fairseq2.utils.file import FileSystem +from fairseq2.utils.yaml import YamlError, YamlLoader class AssetMetadataProvider(ABC): """Provides asset metadata.""" @abstractmethod - def get_metadata(self, name: str) -> Dict[str, Any]: + def get_metadata(self, name: str) -> dict[str, object]: """Return the metadata of the specified asset. :param name: @@ -34,7 +33,7 @@ def get_metadata(self, name: str) -> Dict[str, Any]: """ @abstractmethod - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: """Return the names of the assets for which this provider has metadata.""" @abstractmethod @@ -45,26 +44,37 @@ def clear_cache(self) -> None: class AbstractAssetMetadataProvider(AssetMetadataProvider): """Provides a skeletal implementation of :class:`AssetMetadataProvider`.""" - _cache: Optional[Dict[str, Dict[str, Any]]] + _cache: dict[str, dict[str, object]] | None def __init__(self) -> None: + """ + :param scope: + The scope of the provider. + """ self._cache = None @final @override - def get_metadata(self, name: str) -> Dict[str, Any]: + def get_metadata(self, name: str) -> dict[str, object]: cache = self._ensure_cache_loaded() try: - return deepcopy(cache[name]) + metadata = cache[name] except KeyError: - raise AssetNotFoundError( - name, f"An asset with the name '{name}' cannot be found." + raise AssetMetadataNotFoundError( + f"An asset metadata with name '{name}' is not found." ) from None + try: + return deepcopy(metadata) + except Exception as ex: + raise ContractError( + f"The metadata of the '{name}' asset cannot be copied. See the nested exception for details." + ) from ex + @final @override - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: cache = self._ensure_cache_loaded() return list(cache.keys()) @@ -74,7 +84,7 @@ def get_names(self) -> List[str]: def clear_cache(self) -> None: self._cache = None - def _ensure_cache_loaded(self) -> Dict[str, Dict[str, Any]]: + def _ensure_cache_loaded(self) -> dict[str, dict[str, object]]: if self._cache is not None: return self._cache @@ -83,7 +93,7 @@ def _ensure_cache_loaded(self) -> Dict[str, Dict[str, Any]]: return self._cache @abstractmethod - def _load_cache(self) -> Dict[str, Dict[str, Any]]: + def _load_cache(self) -> dict[str, dict[str, object]]: ... @@ -91,46 +101,63 @@ def _load_cache(self) -> Dict[str, Dict[str, Any]]: class FileAssetMetadataProvider(AbstractAssetMetadataProvider): """Provides asset metadata stored on a file system.""" - _base_dir: Path + _path: Path + _file_system: FileSystem + _yaml_loader: YamlLoader - def __init__(self, base_dir: Path) -> None: - """ - :param base_dir: - The base directory under which the asset metadata is stored. - """ + def __init__( + self, path: Path, file_system: FileSystem, yaml_loader: YamlLoader + ) -> None: super().__init__() - self._base_dir = base_dir.expanduser().resolve() - - self._cache = None + self._path = path + self._file_system = file_system + self._yaml_loader = yaml_loader @override - def _load_cache(self) -> Dict[str, Dict[str, Any]]: - def on_error(ex: OSError) -> NoReturn: - raise AssetMetadataError( - f"The base asset metadata directory '{self._base_dir}' cannot be traversed. See nested exception for details." - ) from ex + def _load_cache(self) -> dict[str, dict[str, object]]: + path = self._file_system.resolve(self._path) cache = {} - for dir_pathname, _, filenames in os.walk(self._base_dir, onerror=on_error): - metadata_dir = Path(dir_pathname) + def cache_file(file: Path, source: str) -> None: + for name, metadata in load_metadata_file(file, self._yaml_loader): + if name in cache: + if file == path: + raise AssetMetadataError( + f"Two assets in the '{path}' file have the same name '{name}'." + ) + else: + raise AssetMetadataError( + f"Two assets under the '{path}' directory have the same name '{name}'." + ) - for filename in filenames: - file = metadata_dir.joinpath(filename) + metadata["__source__"] = source - if file.suffix != ".yaml" and file.suffix != ".yml": - continue + cache[name] = metadata - for name, metadata in load_metadata_file(file): - if name in cache: - raise AssetMetadataError( - f"Two assets under the directory '{self._base_dir}' have the same name '{name}'." - ) + if path.is_dir(): + source = f"directory:{path}" + + def on_error(ex: OSError) -> NoReturn: + raise AssetMetadataError( + f"The '{path}' base asset metadata directory cannot be traversed. See the nested exception for details." + ) from ex + + for dir_pathname, filenames in self._file_system.walk_directory( + path, on_error=on_error + ): + metadata_dir = Path(dir_pathname) - metadata["__source__"] = f"directory:{self._base_dir}" + for filename in filenames: + file = metadata_dir.joinpath(filename) - cache[name] = metadata + if file.suffix != ".yaml" and file.suffix != ".yml": + continue + + cache_file(file, source) + else: + cache_file(path, source=f"file:{path}") return cache @@ -140,47 +167,61 @@ class PackageAssetMetadataProvider(AbstractAssetMetadataProvider): """Provides asset metadata stored in a Python namespace package.""" _package_name: str - _package_path: MultiplexedPath + _package_file_lister: PackageFileLister + _yaml_loader: YamlLoader - def __init__(self, package_name: str) -> None: - """ - :param package_name: - The name of the package in which the asset metadata is stored. - """ + def __init__( + self, + package_name: str, + package_file_lister: PackageFileLister, + yaml_loader: YamlLoader, + ) -> None: super().__init__() self._package_name = package_name - - self._package_path = files(package_name) + self._package_file_lister = package_file_lister + self._yaml_loader = yaml_loader @override - def _load_cache(self) -> Dict[str, Dict[str, Any]]: + def _load_cache(self) -> dict[str, dict[str, object]]: + source = f"package:{self._package_name}" + cache = {} - for file in self._list_files(): + for file in self._package_file_lister.list(self._package_name): if file.suffix != ".yaml" and file.suffix != ".yml": continue - for name, metadata in load_metadata_file(file): + for name, metadata in load_metadata_file(file, self._yaml_loader): if name in cache: raise AssetMetadataError( - f"Two assets under the namespace package '{self._package_name}' have the same name '{name}'." + f"Two assets in the '{self._package_name}' package have the same name '{name}'." ) - metadata["__source__"] = f"package:{self._package_name}" + metadata["__source__"] = source cache[name] = metadata return cache - def _list_files(self) -> List[Path]: + +class PackageFileLister(ABC): + @abstractmethod + def list(self, package_name: str) -> list[Path]: + ... + + +@final +class WheelPackageFileLister(PackageFileLister): + @override + def list(self, package_name: str) -> list[Path]: files = [] - def collect_files(p: Union[MultiplexedPath, Path]) -> None: + def collect_files(p: MultiplexedPath | Path) -> None: if p.is_file(): if not isinstance(p, Path): - raise RuntimeError( - "`importlib.resources` returned a file path that is not of type `pathlib.Path`. Please file a bug report." + raise InternalError( + f"`importlib.resources` returned a path of type `{type(p)}`." ) files.append(p) @@ -188,53 +229,56 @@ def collect_files(p: Union[MultiplexedPath, Path]) -> None: for e in p.iterdir(): collect_files(e) - collect_files(self._package_path) + path = get_files(package_name) + + collect_files(path) return files -def load_metadata_file(file: Path) -> List[Tuple[str, Dict[str, Any]]]: +def load_metadata_file( + file: Path, yaml_loader: YamlLoader +) -> list[tuple[str, dict[str, object]]]: """Load asset metadata included in ``file``.""" output = [] try: - fp = file.open() - except OSError as ex: + all_metadata = yaml_loader(file) + except (OSError, YamlError) as ex: raise AssetMetadataError( - f"The asset metadata file '{file}' cannot be opened. See nested exception for details." + f"The '{file}' asset metadata file cannot be loaded as YAML. See the nested exception for details." ) from ex - with fp: + for idx, metadata in enumerate(all_metadata): + if not isinstance(metadata, dict): + raise AssetMetadataError( + f"The asset metadata at index {idx} in the '{file}' file is expected to be of type `dict`, but is of type `{type(metadata)}` instead." + ) + + try: + name = metadata.pop("name") + except KeyError: + raise AssetMetadataError( + f"The asset metadata at index {idx} in the '{file}' file does not have a name." + ) from None + try: - all_metadata = yaml.safe_load_all(fp) - except (OSError, YAMLError) as ex: + canonical_name = _canonicalize_name(name) + except ValueError as ex: raise AssetMetadataError( - f"The asset metadata file '{file}' cannot be loaded. See nested exception for details." + f"The asset metadata at index {idx} in the '{file}' file does not have a valid name. See the nested exception for details." ) from ex - for idx, metadata in enumerate(all_metadata): - if not isinstance(metadata, dict): + base = metadata.get("base") + if base is not None: + if not isinstance(base, str) or "@" in base: raise AssetMetadataError( - f"The asset metadata at index {idx} in {file} has an invalid format." + f"The asset metadata at index {idx} in the '{file}' file does not have a valid base name." ) - try: - name = metadata.pop("name") - except KeyError: - raise AssetMetadataError( - f"The asset metadata at index {idx} in {file} does not have a name entry." - ) from None - - try: - canonical_name = _canonicalize_name(name) - except ValueError as ex: - raise AssetMetadataError( - f"The asset metadata at index {idx} in {file} has an invalid name. See nested exception for details." - ) from ex - - metadata["__base_path__"] = file.parent + metadata["__base_path__"] = file.parent - output.append((canonical_name, metadata)) + output.append((canonical_name, metadata)) return output @@ -243,33 +287,27 @@ def load_metadata_file(file: Path) -> List[Tuple[str, Dict[str, Any]]]: class InProcAssetMetadataProvider(AssetMetadataProvider): """Provides asset metadata stored in memory.""" - _name: Optional[str] - _metadata: Dict[str, Dict[str, Any]] - - def __init__( - self, metadata: Sequence[Dict[str, Any]], *, name: Optional[str] = None - ) -> None: - self._name = name - self._metadata = {} + _metadata: dict[str, dict[str, object]] + _scope: str - source = "inproc" + def __init__(self, metadata: Sequence[dict[str, object]]) -> None: + super().__init__() - if name is not None: - source = f"{source}:{name}" + self._metadata = {} for idx, metadata_ in enumerate(metadata): try: - name_ = metadata_.pop("name") + name = metadata_.pop("name") except KeyError: raise AssetMetadataError( - f"The asset metadata at index {idx} in `metadata` does not have a name entry." + f"The asset metadata at index {idx} in `metadata` does not have a name." ) from None try: - canonical_name = _canonicalize_name(name_) + canonical_name = _canonicalize_name(name) except ValueError as ex: raise AssetMetadataError( - f"The asset metadata at index {idx} in `metadata` has an invalid name. See nested exception for details." + f"The asset metadata at index {idx} in `metadata` does not have a valid name. See the nested exception for details." ) from ex if canonical_name in self._metadata: @@ -277,21 +315,28 @@ def __init__( f"Two assets in `metadata` have the same name '{canonical_name}'." ) - metadata_["__source__"] = source + base = metadata_.get("base") + if base is not None: + if not isinstance(base, str) or "@" in base: + raise AssetMetadataError( + f"The asset metadata at index {idx} in `metadata` file does not have a valid base name." + ) + + metadata_["__source__"] = "inproc" self._metadata[canonical_name] = metadata_ @override - def get_metadata(self, name: str) -> Dict[str, Any]: + def get_metadata(self, name: str) -> dict[str, object]: try: return deepcopy(self._metadata[name]) except KeyError: - raise AssetNotFoundError( - name, f"An asset with the name '{name}' cannot be found." + raise AssetMetadataNotFoundError( + f"An asset metadata with name '{name}' is not found." ) from None @override - def get_names(self) -> List[str]: + def get_names(self) -> list[str]: return list(self._metadata.keys()) @override @@ -299,10 +344,18 @@ def clear_cache(self) -> None: pass -def _canonicalize_name(name: Any) -> str: +class AssetMetadataError(Exception): + pass + + +class AssetMetadataNotFoundError(AssetMetadataError): + pass + + +def _canonicalize_name(name: object) -> str: if not isinstance(name, str): raise ValueError( - f"`name` must be of type `{str}`, but is of type `{type(name)}` instead." + f"`name` must be of type `str`, but is of type `{type(name)}` instead." ) name_env_pair = name.split("@") @@ -316,23 +369,3 @@ def _canonicalize_name(name: Any) -> str: name_env_pair.append("") # empty env return "@".join(name_env_pair) - - -class AssetNotFoundError(AssetError): - """Raised when an asset cannot be found.""" - - _name: str - - def __init__(self, name: str, msg: str) -> None: - super().__init__(msg) - - self._name = name - - @property - def name(self) -> str: - """The name of the asset.""" - return self._name - - -class AssetMetadataError(AssetError): - """Raised when an asset metadata operation fails.""" diff --git a/src/fairseq2/assets/store.py b/src/fairseq2/assets/store.py index 30cdc7658..225d58edb 100644 --- a/src/fairseq2/assets/store.py +++ b/src/fairseq2/assets/store.py @@ -7,21 +7,25 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Protocol, Sequence, final +from typing import Literal, Protocol, TypeAlias, final -from fairseq2.assets.card import AssetCard, AssetCardError +from typing_extensions import override + +from fairseq2.assets.card import AssetCard +from fairseq2.assets.error import AssetCardError, AssetCardNotFoundError from fairseq2.assets.metadata_provider import ( + AssetMetadataNotFoundError, AssetMetadataProvider, - AssetNotFoundError, - FileAssetMetadataProvider, PackageAssetMetadataProvider, + WheelPackageFileLister, ) -from fairseq2.logging import get_log_writer -from fairseq2.typing import override +from fairseq2.error import ContractError from fairseq2.utils.env import get_path_from_env +from fairseq2.utils.yaml import load_yaml -log = get_log_writer(__name__) +AssetScope: TypeAlias = Literal["all", "global", "user"] class AssetStore(ABC): @@ -29,11 +33,7 @@ class AssetStore(ABC): @abstractmethod def retrieve_card( - self, - name: str, - *, - envs: Optional[Sequence[str]] = None, - scope: Literal["all", "global", "user"] = "all", + self, name: str, *, envs: Sequence[str] | None = None, scope: AssetScope = "all" ) -> AssetCard: """Retrieve the card of the specified asset. @@ -48,9 +48,7 @@ def retrieve_card( """ @abstractmethod - def retrieve_names( - self, *, scope: Literal["all", "global", "user"] = "all" - ) -> List[str]: + def retrieve_names(self, *, scope: AssetScope = "all") -> list[str]: """Retrieve the names of the assets contained in this store. :param scope: @@ -62,28 +60,24 @@ def retrieve_names( class StandardAssetStore(AssetStore): """Represents a store of assets.""" - env_resolvers: List[EnvironmentResolver] - metadata_providers: List[AssetMetadataProvider] - user_metadata_providers: List[AssetMetadataProvider] + env_resolvers: list[EnvironmentResolver] + metadata_providers: list[AssetMetadataProvider] + user_metadata_providers: list[AssetMetadataProvider] - def __init__(self, metadata_provider: AssetMetadataProvider) -> None: - """ - :param storage: - The default asset metadata provider. - """ + def __init__(self) -> None: self.env_resolvers = [] - self.metadata_providers = [metadata_provider] + self.metadata_providers = [] self.user_metadata_providers = [] @override def retrieve_card( - self, - name: str, - *, - envs: Optional[Sequence[str]] = None, - scope: Literal["all", "global", "user"] = "all", - extra_provider: Optional[AssetMetadataProvider] = None, + self, name: str, *, envs: Sequence[str] | None = None, scope: AssetScope = "all" ) -> AssetCard: + if scope not in ("all", "global", "user"): + raise ValueError( + f"`scope` must be 'all', 'global', or 'user', but is '{scope}' instead." + ) + name_env_pair = name.split("@", maxsplit=1) name = name_env_pair[0] @@ -92,7 +86,7 @@ def retrieve_card( if len(name_env_pair) == 2: if envs is not None: raise ValueError( - "`name` already contains an environment tag, `envs` must be `None`." + "`envs` must be `None` since `name` already contains an environment tag." ) envs = [name_env_pair[1]] @@ -100,9 +94,9 @@ def retrieve_card( if envs is None: envs = self._resolve_envs() - return self._do_retrieve_card(name, envs, scope, extra_provider) + return self._do_retrieve_card(name, envs, scope) - def _resolve_envs(self) -> List[str]: + def _resolve_envs(self) -> list[str]: # This is a special, always available environment for users to override # asset metadata. For instance, a user can set the checkpoint path of a # gated model locally by having a same-named asset with a @user suffix. @@ -115,20 +109,19 @@ def _resolve_envs(self) -> List[str]: return envs def _do_retrieve_card( - self, - name: str, - envs: Sequence[str], - scope: str, - extra_provider: Optional[AssetMetadataProvider], + self, name: str, envs: Sequence[str], scope: str ) -> AssetCard: - metadata = self._get_metadata(f"{name}@", scope, extra_provider) + try: + metadata = self._get_metadata(f"{name}@", scope) + except AssetMetadataNotFoundError: + raise AssetCardNotFoundError( + name, f"An asset card with name '{name}' is not found." + ) from None # If we have environment-specific metadata, merge it with `metadata`. for env in reversed(envs): try: - env_metadata = self._get_metadata( - f"{name}@{env}", scope, extra_provider - ) + env_metadata = self._get_metadata(f"{name}@{env}", scope) # Do not allow overriding 'name'. try: @@ -137,64 +130,70 @@ def _do_retrieve_card( pass metadata.update(env_metadata) - except AssetNotFoundError: + except AssetMetadataNotFoundError: pass - try: - base_name = metadata["base"] - except KeyError: - base_name = None + def contract_error( + field: str, value: object, expected_kls: object + ) -> ContractError: + return ContractError( + f"The value of the '{field}' field of the '{name}' asset card is expected to be of type `{expected_kls}`, but is of type `{type(value)}` instead." + ) + + base_name = metadata.get("base") - base_card: Optional[AssetCard] = None + base_card: AssetCard | None = None # If the metadata has a base specified, we have to recursively load the # entire chain up to the root. - if base_name: + if base_name is not None: if not isinstance(base_name, str): + raise contract_error("base", base_name, "str") + + try: + base_card = self._do_retrieve_card(base_name, envs, scope) + except AssetCardNotFoundError: raise AssetCardError( - f"The value of the field 'base' of the asset card '{name}' must be of type `{str}`, but is of type `{type(base_name)}` instead." - ) + name, f"A transitive base asset card with name '{base_name}' is not found." # fmt: skip + ) from None - base_card = self._do_retrieve_card(base_name, envs, scope, extra_provider) + base_path = metadata.get("__base_path__") + if base_path is not None and not isinstance(base_path, Path): + raise contract_error("__base_path__", base_path, Path) metadata["name"] = name - return AssetCard(metadata, base_card) - - def _get_metadata( - self, name: str, scope: str, extra_provider: Optional[AssetMetadataProvider] - ) -> Dict[str, Any]: - if extra_provider is not None: - try: - return extra_provider.get_metadata(name) - except AssetNotFoundError: - pass + return AssetCard(name, metadata, base_card, base_path) + def _get_metadata(self, name: str, scope: str) -> dict[str, object]: if scope == "all" or scope == "user": for provider in reversed(self.user_metadata_providers): try: return provider.get_metadata(name) - except AssetNotFoundError: + except AssetMetadataNotFoundError: continue if scope == "all" or scope == "global": for provider in reversed(self.metadata_providers): try: return provider.get_metadata(name) - except AssetNotFoundError: + except AssetMetadataNotFoundError: continue if name[-1] == "@": name = name[:-1] - raise AssetNotFoundError( - name, f"An asset with the name '{name}' cannot be found. Run `fairseq2 assets list` to see the list of available assets." # fmt: skip - ) + raise AssetMetadataNotFoundError( + f"An asset metadata with name '{name}' is not found." + ) from None @override - def retrieve_names( - self, *, scope: Literal["all", "global", "user"] = "all" - ) -> List[str]: + def retrieve_names(self, *, scope: AssetScope = "all") -> list[str]: + if scope not in ("all", "global", "user"): + raise ValueError( + f"`scope` must be 'all', 'global', or 'user', but is '{scope}' instead." + ) + names = [] if scope == "all" or scope == "user": @@ -215,72 +214,48 @@ def clear_cache(self) -> None: for provider in self.user_metadata_providers: provider.clear_cache() - def add_file_metadata_provider(self, path: Path, user: bool = False) -> None: - """Add a new :class:`FileAssetMetadataProvider` pointing to ``path``. - - :param path: - The directory under which asset metadata is stored. - :param user: - If ``True``, adds the metadata provider to the user scope. - """ - providers = self.user_metadata_providers if user else self.metadata_providers - - providers.append(FileAssetMetadataProvider(path)) - def add_package_metadata_provider(self, package_name: str) -> None: """Add a new :class:`PackageAssetMetadataProvider` for ``package_name``. - :param package_name: - The name of the package in which asset metadata is stored. + :param package_name: The name of the package in which asset metadata is + stored. """ - self.metadata_providers.append(PackageAssetMetadataProvider(package_name)) - - -class EnvironmentResolver(Protocol): - """Resolves the environment within which assets should be loaded. + file_lister = WheelPackageFileLister() - Assets can have varying metadata depending on the environment that they are - loaded in due to legal or technical requirements. - """ + provider = PackageAssetMetadataProvider(package_name, file_lister, load_yaml) - def __call__(self) -> Optional[str]: - ... + self.metadata_providers.append(provider) -def _create_default_asset_store() -> StandardAssetStore: - metadata_provider = PackageAssetMetadataProvider("fairseq2.assets.cards") +class EnvironmentResolver(Protocol): + """Resolves the environment within which assets should be loaded.""" - return StandardAssetStore(metadata_provider) + def __call__(self) -> str | None: + ... -default_asset_store = _create_default_asset_store() +default_asset_store = StandardAssetStore() -def _load_asset_directory() -> None: - asset_dir = get_path_from_env("FAIRSEQ2_ASSET_DIR", log) +def get_asset_dir() -> Path | None: + asset_dir = get_path_from_env("FAIRSEQ2_ASSET_DIR") if asset_dir is None: asset_dir = Path("/etc/fairseq2/assets").resolve() if not asset_dir.exists(): - return - - default_asset_store.add_file_metadata_provider(asset_dir) + return None + return asset_dir -_load_asset_directory() - -def _load_user_asset_directory() -> None: - asset_dir = get_path_from_env("FAIRSEQ2_USER_ASSET_DIR", log) +def get_user_asset_dir() -> Path | None: + asset_dir = get_path_from_env("FAIRSEQ2_USER_ASSET_DIR") if asset_dir is None: - asset_dir = get_path_from_env("XDG_CONFIG_HOME", log) + asset_dir = get_path_from_env("XDG_CONFIG_HOME") if asset_dir is None: asset_dir = Path("~/.config").expanduser() asset_dir = asset_dir.joinpath("fairseq2/assets").resolve() if not asset_dir.exists(): - return - - default_asset_store.add_file_metadata_provider(asset_dir, user=True) - + return None -_load_user_asset_directory() + return asset_dir diff --git a/src/fairseq2/chatbots/__init__.py b/src/fairseq2/chatbots/__init__.py new file mode 100644 index 000000000..b44133ed2 --- /dev/null +++ b/src/fairseq2/chatbots/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.chatbots.chatbot import AbstractChatbot as AbstractChatbot +from fairseq2.chatbots.chatbot import Chatbot as Chatbot +from fairseq2.chatbots.chatbot import ChatDialog as ChatDialog +from fairseq2.chatbots.chatbot import ChatMessage as ChatMessage +from fairseq2.chatbots.handler import ChatbotHandler as ChatbotHandler +from fairseq2.chatbots.handler import ChatbotNotFoundError as ChatbotNotFoundError +from fairseq2.chatbots.static import create_chatbot as create_chatbot diff --git a/src/fairseq2/generation/chatbot.py b/src/fairseq2/chatbots/chatbot.py similarity index 73% rename from src/fairseq2/generation/chatbot.py rename to src/fairseq2/chatbots/chatbot.py index 00c7cbeb0..245cd6117 100644 --- a/src/fairseq2/generation/chatbot.py +++ b/src/fairseq2/chatbots/chatbot.py @@ -7,18 +7,17 @@ from __future__ import annotations from abc import ABC, abstractmethod -from contextlib import nullcontext +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, ContextManager, List, Literal, Optional, Sequence, Tuple, final +from typing import Literal, TypeAlias, final from torch import Tensor -from typing_extensions import TypeAlias +from typing_extensions import override from fairseq2.data.text import TextTokenDecoder, TextTokenizer +from fairseq2.error import ContractError from fairseq2.generation.generator import SequenceGenerator, SequenceGeneratorOutput -from fairseq2.generation.utils import _StdOutPrintHook from fairseq2.nn.padding import PaddingMask, pad_seqs -from fairseq2.typing import override @final @@ -41,13 +40,11 @@ class Chatbot(ABC): @abstractmethod def __call__( - self, dialog: ChatDialog, *, stdout: bool = False - ) -> Tuple[ChatMessage, SequenceGeneratorOutput]: + self, dialog: ChatDialog + ) -> tuple[ChatMessage, SequenceGeneratorOutput]: """ :param dialog: The chat dialog that the bot should respond to. - :param stdout: - If ``True``, prints the generated message in real-time to stdout. :returns: - The response message of the bot. @@ -57,7 +54,7 @@ def __call__( @abstractmethod def batch_response( self, dialogs: Sequence[ChatDialog] - ) -> Tuple[List[ChatMessage], SequenceGeneratorOutput]: + ) -> tuple[list[ChatMessage], SequenceGeneratorOutput]: """ :param dialogs: The chat dialogs that the bot should respond to. @@ -93,23 +90,13 @@ def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> No @final @override def __call__( - self, dialog: ChatDialog, *, stdout: bool = False - ) -> Tuple[ChatMessage, SequenceGeneratorOutput]: + self, dialog: ChatDialog + ) -> tuple[ChatMessage, SequenceGeneratorOutput]: dialog_seq = self._encode_dialog(dialog, "dialog") - cm: ContextManager[Any] - - if stdout: - hook = _StdOutPrintHook(self._text_decoder) - - cm = self._generator.register_step_hook(hook) - else: - cm = nullcontext() - - with cm: - responses, generator_output = self.__do_response( - dialog_seq.unsqueeze(0), dialog_padding_mask=None - ) + responses, generator_output = self.__do_response( + dialog_seq.unsqueeze(0), dialog_padding_mask=None + ) return responses[0], generator_output @@ -117,7 +104,7 @@ def __call__( @override def batch_response( self, dialogs: Sequence[ChatDialog] - ) -> Tuple[List[ChatMessage], SequenceGeneratorOutput]: + ) -> tuple[list[ChatMessage], SequenceGeneratorOutput]: """ :param dialogs: The chat dialogs that the bot should respond to. @@ -135,16 +122,16 @@ def batch_response( return self.__do_response(dialog_seqs, dialog_padding_mask) def __do_response( - self, dialog_seqs: Tensor, dialog_padding_mask: Optional[PaddingMask] - ) -> Tuple[List[ChatMessage], SequenceGeneratorOutput]: + self, dialog_seqs: Tensor, dialog_padding_mask: PaddingMask | None + ) -> tuple[list[ChatMessage], SequenceGeneratorOutput]: generator_output = self._generator(dialog_seqs, dialog_padding_mask) - responses: List[ChatMessage] = [] + responses: list[ChatMessage] = [] for idx, hypotheses in enumerate(generator_output.hypotheses): if len(hypotheses) == 0: - raise RuntimeError( - f"The sequence generator returned no hypothesis at index {idx}. Please file a bug report." + raise ContractError( + f"The sequence generator returned no hypothesis at index {idx}." ) response = ChatMessage( diff --git a/src/fairseq2/chatbots/handler.py b/src/fairseq2/chatbots/handler.py new file mode 100644 index 000000000..6d425de99 --- /dev/null +++ b/src/fairseq2/chatbots/handler.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from fairseq2.chatbots.chatbot import Chatbot +from fairseq2.data.text import TextTokenizer +from fairseq2.generation.generator import SequenceGenerator + + +class ChatbotHandler(ABC): + @abstractmethod + def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: + ... + + +class ChatbotNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known chatbot.") + + self.name = name diff --git a/src/fairseq2/models/llama/chatbot.py b/src/fairseq2/chatbots/llama.py similarity index 68% rename from src/fairseq2/models/llama/chatbot.py rename to src/fairseq2/chatbots/llama.py index e3d9f754d..8e14f1591 100644 --- a/src/fairseq2/models/llama/chatbot.py +++ b/src/fairseq2/chatbots/llama.py @@ -6,24 +6,18 @@ from __future__ import annotations -from typing import List, final +from typing import final import torch from torch import Tensor +from typing_extensions import override +from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog, ChatMessage +from fairseq2.chatbots.handler import ChatbotHandler from fairseq2.data.text import TextTokenEncoder, TextTokenizer -from fairseq2.generation import ( - AbstractChatbot, - Chatbot, - ChatDialog, - ChatMessage, - SequenceGenerator, -) -from fairseq2.models.chatbot import create_chatbot -from fairseq2.models.llama.factory import LLAMA_FAMILY -from fairseq2.models.llama.tokenizer import LLaMA3Tokenizer +from fairseq2.data.text.tokenizers.llama import LLaMA3Tokenizer +from fairseq2.generation import SequenceGenerator from fairseq2.nn.utils.module import infer_device -from fairseq2.typing import override @final @@ -47,11 +41,14 @@ def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> No eos_idx = tokenizer.vocab_info.eos_idx if bos_idx is None or eos_idx is None: - raise RuntimeError( - "One or more required control symbols requierd for the chatbot are not found in the tokenizer. Please make sure that you are using the right tokenizer." - ) + raise ValueError("`tokenizer` must have BOS and EOS symbols defined.") - device = infer_device(generator.model, name="generator.model") + try: + device = infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex self._bos_idx = torch.tensor([bos_idx], device=device) self._eos_idx = torch.tensor([eos_idx], device=device) @@ -62,12 +59,12 @@ def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> No def _encode_dialog(self, dialog: ChatDialog, param_name: str) -> Tensor: if len(dialog) == 0: raise ValueError( - f"`{param_name}` must have at least one message with the role 'user'." + f"`{param_name}` must have at least one message with the 'user' role." ) if dialog[-1].role != "user": raise ValueError( - f"The last message of `{param_name}` must have the role 'user'." + f"The last message of `{param_name}` must have the 'user' role." ) # Merge the system message, if any, with the first user message. @@ -78,12 +75,12 @@ def _encode_dialog(self, dialog: ChatDialog, param_name: str) -> Tensor: dialog = [first_message] + list(dialog[2:]) - dialog_contents: List[Tensor] = [] + dialog_contents: list[Tensor] = [] for user, bot in zip(dialog[::2], dialog[1::2]): if user.role != "user" or bot.role != "bot": raise ValueError( - f"The messages of `{param_name}` might optionally start with the role 'system', and then must alternate between the roles 'user' and 'bot'." + f"The messages of `{param_name}` might optionally start with the 'system' role, and then must alternate between the 'user' and 'bot' roles." ) user_bot_seq = self._text_encoder( @@ -109,15 +106,13 @@ class LLaMA3Chatbot(AbstractChatbot): """Represents a LLaMA 3 chatbot.""" _bos_idx: Tensor + _eos_idx: Tensor _boh_idx: Tensor _eoh_idx: Tensor - _eot_idx: Tensor _text_encoder: TextTokenEncoder _break: Tensor - def __init__( - self, generator: SequenceGenerator, tokenizer: LLaMA3Tokenizer - ) -> None: + def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> None: """ :param generator: The sequence generator. @@ -126,22 +121,27 @@ def __init__( """ super().__init__(generator, tokenizer) - device = infer_device(generator.model, name="generator.model") + bos_idx = tokenizer.vocab_info.bos_idx + eos_idx = tokenizer.vocab_info.eos_idx + boh_idx = tokenizer.vocab_info.boh_idx + eoh_idx = tokenizer.vocab_info.eoh_idx + + if bos_idx is None or eos_idx is None or boh_idx is None or eoh_idx is None: + raise ValueError( + "`tokenizer` must have BOS, EOS, BOH, EOH symbols defined." + ) try: - bos_idx = tokenizer.encoding.encode_single_token("<|begin_of_text|>") - boh_idx = tokenizer.encoding.encode_single_token("<|start_header_id|>") - eoh_idx = tokenizer.encoding.encode_single_token("<|end_header_id|>") - eot_idx = tokenizer.encoding.encode_single_token("<|eot_id|>") - except KeyError: - raise RuntimeError( - "One or more special symbols required for the chatbot are not found in the tokenizer. Please file a bug report to the model author." - ) from None + device = infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex self._bos_idx = torch.tensor([bos_idx], device=device) + self._eos_idx = torch.tensor([eos_idx], device=device) self._boh_idx = torch.tensor([boh_idx], device=device) self._eoh_idx = torch.tensor([eoh_idx], device=device) - self._eot_idx = torch.tensor([eot_idx], device=device) self._text_encoder = tokenizer.create_raw_encoder(device=device) @@ -151,15 +151,15 @@ def __init__( def _encode_dialog(self, dialog: ChatDialog, param_name: str) -> Tensor: if len(dialog) == 0: raise ValueError( - f"`{param_name}` must have at least one message with the role 'user'." + f"`{param_name}` must have at least one message with the 'user' role." ) if dialog[-1].role != "user": raise ValueError( - f"The last message of `{param_name}` must have the role 'user'." + f"The last message of `{param_name}` must have the 'user' role." ) - dialog_contents: List[Tensor] = [self._bos_idx] + dialog_contents: list[Tensor] = [self._bos_idx] def encode_role(role: str) -> None: seq = self._text_encoder(role) @@ -169,7 +169,7 @@ def encode_role(role: str) -> None: def encode_content(content: str) -> None: seq = self._text_encoder(content.strip()) - dialog_contents.extend([seq, self._eot_idx]) + dialog_contents.extend([seq, self._eos_idx]) if dialog[0].role == "system": encode_role("system") @@ -181,7 +181,7 @@ def encode_content(content: str) -> None: for user, bot in zip(dialog[::2], dialog[1::2]): if user.role != "user" or bot.role != "bot": raise ValueError( - f"The messages of `{param_name}` might optionally start with the role 'system', and then must alternate between the roles 'user' and 'bot'." + f"The messages of `{param_name}` might optionally start with the 'system' role, and then must alternate between the 'user' and 'bot' roles." ) encode_role("user") @@ -206,14 +206,11 @@ def supports_system_prompt(self) -> bool: return True -def create_llama_chatbot( - generator: SequenceGenerator, tokenizer: TextTokenizer -) -> Chatbot: - """Create the appropriate LLaMA chatbot based on ``tokenizer``.""" - if isinstance(tokenizer, LLaMA3Tokenizer): - return LLaMA3Chatbot(generator, tokenizer) - - return LLaMAChatbot(generator, tokenizer) - +@final +class LLaMAChatbotHandler(ChatbotHandler): + @override + def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: + if isinstance(tokenizer, LLaMA3Tokenizer): + return LLaMA3Chatbot(generator, tokenizer) -create_chatbot.register(LLAMA_FAMILY, create_llama_chatbot) + return LLaMAChatbot(generator, tokenizer) diff --git a/src/fairseq2/models/mistral/chatbot.py b/src/fairseq2/chatbots/mistral.py similarity index 71% rename from src/fairseq2/models/mistral/chatbot.py rename to src/fairseq2/chatbots/mistral.py index dd38e2941..ba86f8d30 100644 --- a/src/fairseq2/models/mistral/chatbot.py +++ b/src/fairseq2/chatbots/mistral.py @@ -6,17 +6,17 @@ from __future__ import annotations -from typing import List, final +from typing import final import torch from torch import Tensor +from typing_extensions import override +from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog +from fairseq2.chatbots.handler import ChatbotHandler from fairseq2.data.text import TextTokenEncoder, TextTokenizer -from fairseq2.generation import AbstractChatbot, ChatDialog, SequenceGenerator -from fairseq2.models.chatbot import create_chatbot -from fairseq2.models.mistral.factory import MISTRAL_FAMILY +from fairseq2.generation import SequenceGenerator from fairseq2.nn.utils.module import infer_device -from fairseq2.typing import override @final @@ -40,11 +40,14 @@ def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> No eos_idx = tokenizer.vocab_info.eos_idx if bos_idx is None or eos_idx is None: - raise RuntimeError( - "One or more required control symbols requierd for the chatbot are not found in the tokenizer. Please make sure that you are using the right tokenizer." - ) + raise ValueError("`tokenizer` must have BOS and EOS symbols defined.") - device = infer_device(generator.model, name="generator.model") + try: + device = infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex self._bos_idx = torch.tensor([bos_idx], device=device) self._eos_idx = torch.tensor([eos_idx], device=device) @@ -55,20 +58,20 @@ def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> No def _encode_dialog(self, dialog: ChatDialog, param_name: str) -> Tensor: if len(dialog) == 0: raise ValueError( - f"`{param_name}` must have at least one message with the role 'user'." + f"`{param_name}` must have at least one message with the 'user' role." ) if dialog[-1].role != "user": raise ValueError( - f"The last message of `{param_name}` must have the role 'user'." + f"The last message of `{param_name}` must have the 'user' role." ) - dialog_contents: List[Tensor] = [self._bos_idx] + dialog_contents: list[Tensor] = [self._bos_idx] for user, bot in zip(dialog[::2], dialog[1::2]): if user.role != "user" or bot.role != "bot": raise ValueError( - f"The messages of `{param_name}` must alternate between the roles 'user' and 'bot'." + f"The messages of `{param_name}` must alternate between the 'user' and 'bot' roles." ) user_bot_seq = self._text_encoder( @@ -89,4 +92,8 @@ def supports_system_prompt(self) -> bool: return False -create_chatbot.register(MISTRAL_FAMILY, MistralChatbot) +@final +class MistralChatbotHandler(ChatbotHandler): + @override + def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: + return MistralChatbot(generator, tokenizer) diff --git a/src/fairseq2/chatbots/static.py b/src/fairseq2/chatbots/static.py new file mode 100644 index 000000000..e9db21970 --- /dev/null +++ b/src/fairseq2/chatbots/static.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.chatbots.chatbot import Chatbot +from fairseq2.chatbots.handler import ChatbotHandler, ChatbotNotFoundError +from fairseq2.context import get_runtime_context +from fairseq2.data.text import TextTokenizer +from fairseq2.generation.generator import SequenceGenerator + + +def create_chatbot( + name: str, generator: SequenceGenerator, tokenizer: TextTokenizer +) -> Chatbot: + context = get_runtime_context() + + registry = context.get_registry(ChatbotHandler) + + try: + handler = registry.get(name) + except LookupError: + raise ChatbotNotFoundError(name) from None + + return handler.create(generator, tokenizer) diff --git a/src/fairseq2/checkpoint/manager.py b/src/fairseq2/checkpoint/manager.py index 2577f8e6e..0616b1e66 100644 --- a/src/fairseq2/checkpoint/manager.py +++ b/src/fairseq2/checkpoint/manager.py @@ -6,43 +6,60 @@ from __future__ import annotations +import warnings from abc import ABC, abstractmethod -from contextlib import nullcontext +from collections.abc import Iterator, Mapping, Set +from contextlib import AbstractContextManager, nullcontext from pathlib import Path -from pickle import PickleError from shutil import rmtree -from typing import ( - AbstractSet, - Any, - ContextManager, - Dict, - Iterator, - List, - Mapping, - NoReturn, - Optional, - Tuple, - final, -) +from typing import final +from warnings import catch_warnings -import yaml from torch.distributed._shard import load_with_process_group from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType from torch.nn import Module +from typing_extensions import override +from fairseq2.error import InvalidOperationError, NotSupportedError from fairseq2.gang import Gang -from fairseq2.logging import get_log_writer -from fairseq2.typing import CPU, DataClass, override -from fairseq2.utils.dataclass import to_safe_dict -from fairseq2.utils.file import TensorDumper, TensorLoader, dump_tensors, load_tensors - -log = get_log_writer(__name__) +from fairseq2.logging import log +from fairseq2.typing import CPU +from fairseq2.utils.file import ( + TensorDumpError, + TensorLoadError, + dump_torch_tensors, + load_torch_tensors, +) +from fairseq2.utils.structured import unstructure +from fairseq2.utils.yaml import dump_yaml class CheckpointManager(ABC): """Saves and loads training checkpoints.""" + @abstractmethod + def save_model_metadata( + self, + *, + base_asset: str | None = None, + family: str | None = None, + config: object = None, + ) -> None: + """Set the model metadata. + + :param base_asset: + The name of the asset that the model is based on. + :param family: + The family of the model. + :param config: + The configuration of the model. + """ + + @abstractmethod + def save_tokenizer_metadata(self, name: str) -> None: + ... + @abstractmethod def begin_checkpoint(self, step_nr: int) -> None: """Begin a transactional checkpoint operation. @@ -54,10 +71,10 @@ def begin_checkpoint(self, step_nr: int) -> None: @abstractmethod def save_state( self, - state: Mapping[str, Any], + state: Mapping[str, object], *, model_key: str = "model", - replicated_keys: Optional[AbstractSet[str]] = None, + replicated_keys: Set[str] | None = None, ) -> None: """Save the training state. @@ -71,7 +88,7 @@ def save_state( """ @abstractmethod - def save_metadata(self, metadata: Mapping[str, Any]) -> None: + def save_metadata(self, metadata: Mapping[str, object]) -> None: """Save ``metadata`` associated with the checkpoint. :param metadata: @@ -79,7 +96,7 @@ def save_metadata(self, metadata: Mapping[str, Any]) -> None: """ @abstractmethod - def save_score(self, score: Optional[float]) -> None: + def save_score(self, score: float | None) -> None: """Save the score of the checkpoint.""" @abstractmethod @@ -91,11 +108,11 @@ def commit_checkpoint(self) -> None: """Commit the checkpoint after which it will be considered saved.""" @abstractmethod - def load_checkpoint(self, step_nr: int) -> Dict[str, Any]: + def load_checkpoint(self, step_nr: int) -> dict[str, object]: """Load the checkpoint of the specified training step.""" @abstractmethod - def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]: + def load_last_checkpoint(self) -> tuple[int, dict[str, object]]: """Load the last checkpoint in the training. :returns: @@ -104,7 +121,7 @@ def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]: """ @abstractmethod - def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]: + def load_metadata(self, step_nr: int) -> dict[str, object] | None: """Load the checkpoint metadata of the specified training step.""" @abstractmethod @@ -132,7 +149,9 @@ def keep_last_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> No """ @abstractmethod - def keep_best_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> None: + def keep_best_n_checkpoints( + self, n: int, *, preserve_model: bool = False, lower_better: bool = False + ) -> None: """Delete all but the best ``n`` checkpoints based on their score. :param n: @@ -142,7 +161,7 @@ def keep_best_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> No """ @abstractmethod - def has_checkpoint(self, step_nr: Optional[int] = None) -> bool: + def has_checkpoint(self, step_nr: int | None = None) -> bool: """Return ``True`` if the manager holds a checkpoint. :param step_nr: @@ -151,7 +170,7 @@ def has_checkpoint(self, step_nr: Optional[int] = None) -> bool: """ @abstractmethod - def get_step_numbers(self) -> List[int]: + def get_step_numbers(self) -> list[int]: """Return the numbers of the training steps that have a checkpoint.""" @@ -164,21 +183,15 @@ class FileCheckpointManager(CheckpointManager): _dp_gang: Gang _num_shards: int _shard_suffix: str - _tensor_loader: TensorLoader - _tensor_dumper: TensorDumper - _lower_score_better: bool - _checkpoint_step_nr: Optional[int] + _checkpoint_step_nr: int | None def __init__( self, checkpoint_dir: Path, gang: Gang, *, - dp_gang: Optional[Gang] = None, - tp_gang: Optional[Gang] = None, - tensor_loader: Optional[TensorLoader] = None, - tensor_dumper: Optional[TensorDumper] = None, - lower_score_better: bool = False, + dp_gang: Gang | None = None, + tp_gang: Gang | None = None, ) -> None: """ :param checkpoint_dir: @@ -190,12 +203,6 @@ def __init__( :param tp_gang: The gang used for tensor parallelism. Must be specified if ``dp_gang`` is not ``None``. - :param tensor_loader: - The tensor loader to load checkpoints into memory. - :param tensor_dumper: - The tensor dumper to save checkpoints into file. - :param lower_score_better: - If ``True``, lower scores are considered better. """ self._checkpoint_dir = checkpoint_dir.expanduser().resolve() @@ -217,34 +224,18 @@ def __init__( elif dp_gang is not None or tp_gang is not None: raise ValueError("`dp_gang` and `tp_gang` must be both specified.") - self._tensor_loader = tensor_loader or load_tensors - self._tensor_dumper = tensor_dumper or dump_tensors - - self._lower_score_better = lower_score_better - self._checkpoint_step_nr = None + @override def save_model_metadata( self, *, - base_asset: Optional[str] = None, - family: Optional[str] = None, - config: Optional[DataClass] = None, - tokenizer_name: Optional[str] = None, + base_asset: str | None = None, + family: str | None = None, + config: object = None, ) -> None: - """Set the model metadata. - - :param base_asset: - The name of the asset that the model is based on. - :param family: - The family of the model. - :param config: - The configuration of the model. - :param tokenizer_name: - The name of the tokenizer that the model is trained with. - """ if self._root_gang.rank == 0: - metadata: Dict[str, Any] = {"name": "checkpoint"} + metadata: dict[str, object] = {"name": "checkpoint"} if base_asset is not None: metadata["base"] = base_asset @@ -253,29 +244,53 @@ def save_model_metadata( metadata["model_family"] = family if config is not None: - metadata["model_config"] = to_safe_dict(config) + metadata["model_config"] = unstructure(config) if self._num_shards != 1: metadata["num_shards"] = self._num_shards - if tokenizer_name is not None: - metadata["tokenizer_ref"] = tokenizer_name + metadata["tokenizer_ref"] = "checkpoint_tokenizer" try: self._checkpoint_dir.mkdir(parents=True, exist_ok=True) except OSError as ex: - raise RuntimeError( - "The model metadata cannot be saved. See nested exception for details." + raise CheckpointError( + f"The '{self._checkpoint_dir}' directory cannot be created. See the nested exception for details." ) from ex metadata_file = self._checkpoint_dir.joinpath("model.yaml") try: - with metadata_file.open("w") as fp: - yaml.safe_dump(metadata, fp, sort_keys=False) + dump_yaml(metadata, metadata_file) except OSError as ex: - raise RuntimeError( - "The model metadata cannot be saved. See nested exception for details." + raise CheckpointError( + f"The model metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details." + ) from ex + + self._root_gang.barrier() + + @override + def save_tokenizer_metadata(self, name: str) -> None: + if self._root_gang.rank == 0: + metadata: dict[str, object] = { + "name": "checkpoint_tokenizer", + "tokenizer_ref": name, + } + + try: + self._checkpoint_dir.mkdir(parents=True, exist_ok=True) + except OSError as ex: + raise CheckpointError( + f"The '{self._checkpoint_dir}' directory cannot be created. See the nested exception for details." + ) from ex + + metadata_file = self._checkpoint_dir.joinpath("tokenizer.yaml") + + try: + dump_yaml(metadata, metadata_file) + except OSError as ex: + raise CheckpointError( + f"The tokenizer metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details." ) from ex self._root_gang.barrier() @@ -283,9 +298,14 @@ def save_model_metadata( @override def begin_checkpoint(self, step_nr: int) -> None: if self._checkpoint_step_nr is not None: - raise ValueError("`begin_checkpoint()` has already been called.") + raise InvalidOperationError("`begin_checkpoint()` has already been called.") - self.delete_checkpoint(step_nr, missing_ok=True) + try: + self.delete_checkpoint(step_nr, missing_ok=True) + except CheckpointError as ex: + raise CheckpointError( + f"The previous checkpoint of training step {step_nr} cannot be deleted. See the nested exception for details." + ) from ex if self._root_gang.rank == 0: tmp_step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}.tmp") @@ -293,8 +313,8 @@ def begin_checkpoint(self, step_nr: int) -> None: try: tmp_step_dir.mkdir(parents=True) except OSError as ex: - raise RuntimeError( - f"The checkpoint directory of training step {step_nr} cannot be created. See nested exception for details." + raise CheckpointError( + f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be created. See the nested exception for details." ) from ex self._root_gang.barrier() @@ -304,18 +324,13 @@ def begin_checkpoint(self, step_nr: int) -> None: @override def save_state( self, - state: Mapping[str, Any], + state: Mapping[str, object], *, model_key: str = "model", - replicated_keys: Optional[AbstractSet[str]] = None, + replicated_keys: Set[str] | None = None, ) -> None: step_nr = self._get_checkpoint_step_nr() - def raise_error(cause: Exception) -> NoReturn: - raise RuntimeError( - f"The checkpoint of training step {step_nr} cannot be saved. See nested exception for details." - ) from cause - tmp_step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}.tmp") # Copy `state`. In case we fail, it should stay intact. @@ -342,11 +357,13 @@ def model_replicated() -> bool: model_file = tmp_step_dir.joinpath(f"model{self._shard_suffix}.pt") try: - self._tensor_dumper( + dump_torch_tensors( {model_key: state_dict, "model_key": model_key}, model_file ) - except (RuntimeError, OSError, PickleError) as ex: - raise_error(ex) + except TensorDumpError as ex: + raise CheckpointError( + f"The replicated model state of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details." + ) from ex self._root_gang.barrier() @@ -370,9 +387,11 @@ def model_replicated() -> bool: ) try: - self._tensor_dumper(replicated_part, replicated_file) - except (RuntimeError, OSError, PickleError) as ex: - raise_error(ex) + dump_torch_tensors(replicated_part, replicated_file) + except TensorDumpError as ex: + raise CheckpointError( + f"The replicated checkpoint state of training step {step_nr} cannot be saved to the '{replicated_file}' file. See the nested exception for details." + ) from ex else: if "*" in replicated_keys: rank_part.clear() @@ -397,14 +416,16 @@ def model_replicated() -> bool: ) try: - self._tensor_dumper(rank_part, rank_file) - except (RuntimeError, OSError, PickleError) as ex: - raise_error(ex) + dump_torch_tensors(rank_part, rank_file) + except TensorDumpError as ex: + raise CheckpointError( + f"The checkpoint state of training step {step_nr} cannot be saved to the '{rank_file}' file. See the nested exception for details." + ) from ex self._root_gang.barrier() @override - def save_metadata(self, metadata: Mapping[str, Any]) -> None: + def save_metadata(self, metadata: Mapping[str, object]) -> None: step_nr = self._get_checkpoint_step_nr() if metadata is None: @@ -416,16 +437,16 @@ def save_metadata(self, metadata: Mapping[str, Any]) -> None: ) try: - self._tensor_dumper(metadata, metadata_file) - except (RuntimeError, OSError, PickleError) as ex: - raise RuntimeError( - f"The checkpoint metadata of training step {step_nr} cannot be saved. See nested exception for details." + dump_torch_tensors(metadata, metadata_file) + except TensorDumpError as ex: + raise CheckpointError( + f"The checkpoint metadata of training step {step_nr} cannot be saved to the '{metadata_file}' file. See the nested exception for details." ) from ex self._root_gang.barrier() @override - def save_score(self, score: Optional[float]) -> None: + def save_score(self, score: float | None) -> None: step_nr = self._get_checkpoint_step_nr() if self._root_gang.rank == 0: @@ -435,9 +456,8 @@ def save_score(self, score: Optional[float]) -> None: with score_file.open("w") as fp: fp.write(f"{score}\n") except OSError as ex: - raise RuntimeError( - f"The checkpoint score of training step {step_nr} cannot be saved. See nested exception for details.", - step_nr, + raise CheckpointError( + f"The checkpoint score of training step {step_nr} cannot be saved to the '{score_file}' file. See the nested exception for details." ) from ex self._root_gang.barrier() @@ -448,12 +468,17 @@ def save_consolidated_fsdp_model(self, model: Module) -> None: log.info("Extracting consolidated model state.") - with FSDP.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - state_dict = model.state_dict() + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + with FSDP.state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig( + offload_to_cpu=True, rank0_only=True + ), + ): + state_dict = model.state_dict() self._root_gang.barrier() @@ -465,12 +490,12 @@ def save_consolidated_fsdp_model(self, model: Module) -> None: ) try: - self._tensor_dumper( + dump_torch_tensors( {"model": state_dict, "model_key": "model"}, model_file ) - except (RuntimeError, OSError, PickleError) as ex: - raise RuntimeError( - f"The consolidated FSDP model of training step {step_nr} cannot be saved. See nested exception for details." + except TensorDumpError as ex: + raise CheckpointError( + f"The consolidated FSDP model of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details." ) from ex self._root_gang.barrier() @@ -487,8 +512,8 @@ def commit_checkpoint(self) -> None: try: tmp_step_dir.replace(step_dir) except OSError as ex: - raise RuntimeError( - f"The checkpoint of training step {step_nr} cannot be saved. See nested exception for details." + raise CheckpointError( + f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be committed. See the nested exception for details." ) from ex self._root_gang.barrier() @@ -498,37 +523,34 @@ def commit_checkpoint(self) -> None: def _get_checkpoint_step_nr(self) -> int: step_nr = self._checkpoint_step_nr if step_nr is None: - raise ValueError("`begin_checkpoint()` must be called first.") + raise InvalidOperationError("`begin_checkpoint()` must be called first.") return step_nr @override - def load_checkpoint(self, step_nr: int) -> Dict[str, Any]: - def raise_error(cause: Exception) -> NoReturn: - raise RuntimeError( - f"The checkpoint of training step {step_nr} cannot be loaded. See nested exception for details." - ) from cause - - def maybe_with_dp_process_group() -> ContextManager[None]: + def load_checkpoint(self, step_nr: int) -> dict[str, object]: + def maybe_with_dp_process_group() -> AbstractContextManager[None]: try: pg = self._dp_gang.as_process_group() - except RuntimeError: + except NotSupportedError: return nullcontext() return load_with_process_group(pg) step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}") - def load_part(filename: str) -> Dict[str, Any]: + def load_part(filename: str) -> dict[str, object]: with maybe_with_dp_process_group(): # Required for `ShardedTensor`. try: - part = self._tensor_loader( + part = load_torch_tensors( step_dir.joinpath(filename), map_location=CPU ) except FileNotFoundError: part = {} - except (RuntimeError, OSError, PickleError) as ex: - raise_error(ex) + except TensorLoadError as ex: + raise CheckpointError( + f"The '{filename}' checkpoint file of training step {step_nr} cannot be loaded. See the nested exception for details." + ) from ex self._root_gang.barrier() @@ -544,10 +566,7 @@ def load_part(filename: str) -> Dict[str, Any]: # Consolidate the checkpoint parts. checkpoint.update(part) - try: - model_key = checkpoint["model_key"] - except KeyError: - model_key = None + model_key = checkpoint.get("model_key") # If we don't have the model state in the checkpoint so far, it means it # was replicated. @@ -562,12 +581,14 @@ def load_part(filename: str) -> Dict[str, Any]: pass if not checkpoint: - raise CheckpointNotFoundError(f"Training step {step_nr} has no checkpoint.") + raise CheckpointNotFoundError( + f"The checkpoint of training step {step_nr} is not found." + ) return checkpoint @override - def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]: + def load_last_checkpoint(self) -> tuple[int, dict[str, object]]: step_numbers = self.get_step_numbers() if not step_numbers: raise CheckpointNotFoundError("No checkpoint found.") @@ -579,18 +600,18 @@ def load_last_checkpoint(self) -> Tuple[int, Dict[str, Any]]: return last_step_nr, checkpoint @override - def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]: + def load_metadata(self, step_nr: int) -> dict[str, object] | None: metadata_file = self._checkpoint_dir.joinpath( f"step_{step_nr}/metadata{self._shard_suffix}.pt" ) try: - metadata = self._tensor_loader(metadata_file, map_location=CPU) + metadata = load_torch_tensors(metadata_file, map_location=CPU) except FileNotFoundError: metadata = None - except (RuntimeError, OSError, PickleError) as ex: - raise RuntimeError( - f"The checkpoint metadata of training step {step_nr} cannot be loaded. See nested exception for details." + except TensorLoadError as ex: + raise CheckpointError( + f"The checkpoint metadata of training step {step_nr} cannot be loaded from the '{metadata_file}' file. See the nested exception for details." ) from ex self._root_gang.barrier() @@ -601,24 +622,25 @@ def load_metadata(self, step_nr: int) -> Optional[Dict[str, Any]]: def delete_checkpoint( self, step_nr: int, *, missing_ok: bool = False, preserve_model: bool = False ) -> None: - def raise_error(cause: Exception) -> NoReturn: - raise RuntimeError( - f"The checkpoint of training step {step_nr} cannot be deleted. See nested exception for details." - ) from cause - if self._root_gang.rank == 0: step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}") # Delete the temporary checkpoint directory if it exists. + tmp_step_dir = step_dir.with_suffix(".tmp") + try: - rmtree(step_dir.with_suffix(".tmp")) + rmtree(tmp_step_dir) except OSError as ex: if not isinstance(ex, FileNotFoundError): - raise_error(ex) + raise CheckpointError( + f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be deleted. See the nested exception for details." + ) if not step_dir.exists(): if not missing_ok: - raise RuntimeError(f"Training step {step_nr} has no checkpoint.") + raise CheckpointNotFoundError( + f"The '{step_dir}' checkpoint directory of training step {step_nr} is not found." + ) self._root_gang.barrier() @@ -635,13 +657,17 @@ def raise_error(cause: Exception) -> NoReturn: pt_file.unlink() except OSError as ex: if not isinstance(ex, FileNotFoundError): - raise_error(ex) + raise CheckpointError( + f"The '{pt_file}' checkpoint file of training step {step_nr} cannot be deleted. See the nested exception for details." + ) else: try: rmtree(step_dir) except OSError as ex: if not missing_ok or not isinstance(ex, FileNotFoundError): - raise_error(ex) + raise CheckpointError( + f"The '{step_dir}' checkpoint directory of training step {step_nr} cannot be deleted. See the nested exception for details." + ) self._root_gang.barrier() @@ -660,7 +686,9 @@ def keep_last_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> No self.delete_checkpoint(step_nr, preserve_model=preserve_model) @override - def keep_best_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> None: + def keep_best_n_checkpoints( + self, n: int, *, preserve_model: bool = False, lower_better: bool = False + ) -> None: if n == 0: raise ValueError("`n` must be greater than zero.") @@ -670,7 +698,7 @@ def keep_best_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> No last_step_nr = step_numbers[-1] - scores = self._load_scores(step_numbers) + scores = self._load_scores(step_numbers, lower_better) if not scores: return @@ -681,7 +709,9 @@ def keep_best_n_checkpoints(self, n: int, *, preserve_model: bool = False) -> No if step_nr != last_step_nr: self.delete_checkpoint(step_nr, preserve_model=preserve_model) - def _load_scores(self, step_numbers: List[int]) -> List[Tuple[float, int]]: + def _load_scores( + self, step_numbers: list[int], lower_better: bool + ) -> list[tuple[float, int]]: scores = [] for step_nr in step_numbers: @@ -691,20 +721,20 @@ def _load_scores(self, step_numbers: List[int]) -> List[Tuple[float, int]]: with score_file.open() as fp: line = fp.readline() except OSError as ex: - raise RuntimeError( - f"The score of training step {step_nr} cannot be loaded. See nested exception for details." + raise CheckpointError( + f"The score of training step {step_nr} cannot be loaded from the '{score_file}' file. See the nested exception for details." ) from ex try: score = float(line) - except ValueError as ex: - raise RuntimeError( - f"The score of training step {step_nr} cannot be loaded. See nested exception for details." - ) from ex + except ValueError: + raise CheckpointError( + f"The score of training step {step_nr} cannot be parsed as a floating-point number." + ) from None scores.append((score, step_nr)) - if self._lower_score_better: + if lower_better: scores.sort(key=lambda e: (-e[0], e[1])) else: scores.sort() @@ -712,7 +742,7 @@ def _load_scores(self, step_numbers: List[int]) -> List[Tuple[float, int]]: return scores @override - def has_checkpoint(self, step_nr: Optional[int] = None) -> bool: + def has_checkpoint(self, step_nr: int | None = None) -> bool: it = self._iter_step_numbers() if step_nr is None: @@ -721,7 +751,7 @@ def has_checkpoint(self, step_nr: Optional[int] = None) -> bool: return step_nr in it @override - def get_step_numbers(self) -> List[int]: + def get_step_numbers(self) -> list[int]: step_numbers = list(self._iter_step_numbers()) step_numbers.sort() @@ -741,10 +771,14 @@ def _iter_step_numbers(self) -> Iterator[int]: yield step_nr except OSError as ex: - raise RuntimeError( - "The base checkpoint directory cannot be traversed. See nested exception for details." + raise CheckpointError( + f"The base '{self._checkpoint_dir}' checkpoint directory cannot be traversed. See the nested exception for details." ) from ex -class CheckpointNotFoundError(RuntimeError): - """Raised when a checkpoint is not found.""" +class CheckpointError(Exception): + pass + + +class CheckpointNotFoundError(CheckpointError): + pass diff --git a/src/fairseq2/checkpoint/metadata_provider.py b/src/fairseq2/checkpoint/metadata_provider.py index c0c42100e..fa24744f2 100644 --- a/src/fairseq2/checkpoint/metadata_provider.py +++ b/src/fairseq2/checkpoint/metadata_provider.py @@ -7,63 +7,64 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Dict, final +from typing import TypeAlias, final -from fairseq2.assets.metadata_provider import ( +from typing_extensions import override + +from fairseq2.assets import ( AbstractAssetMetadataProvider, AssetMetadataError, load_metadata_file, ) -from fairseq2.typing import override +from fairseq2.utils.yaml import load_yaml @final -class CheckpointModelMetadataProvider(AbstractAssetMetadataProvider): +class FileCheckpointMetadataProvider(AbstractAssetMetadataProvider): """Provides checkpoint model metadata saved by a :class:`FileCheckpointManager.`""" _checkpoint_dir: Path - _lower_score_better: bool - def __init__( - self, checkpoint_dir: Path, *, lower_score_better: bool = False - ) -> None: + def __init__(self, checkpoint_dir: Path) -> None: """ :param checkpoint_dir: The base directory under which the checkpoints are stored. - :param lower_score_better: - If ``True``, lower scores are considered better. """ super().__init__() - self._checkpoint_dir = checkpoint_dir.expanduser().resolve() - - self._lower_score_better = lower_score_better + self._checkpoint_dir = checkpoint_dir @override - def _load_cache(self) -> Dict[str, Dict[str, Any]]: - metadata_file = self._checkpoint_dir.joinpath("model.yaml") - if not metadata_file.exists(): - raise AssetMetadataError( - f"The checkpoint model metadata (model.yaml) cannot be found under {self._checkpoint_dir}. Make sure that the specified directory is the *base* checkpoint directory used during training (i.e. directory passed to `FileCheckpointManager.save_model_metadata()`)." - ) + def _load_cache(self) -> dict[str, dict[str, object]]: + cache: dict[str, dict[str, object]] = {} - cache = dict(load_metadata_file(metadata_file)) + self._load_model(cache) + + self._load_tokenizer(cache) + + return cache + + def _load_model(self, cache: dict[str, dict[str, object]]) -> None: + checkpoint_dir = self._checkpoint_dir.expanduser().resolve() + + metadata_file = checkpoint_dir.joinpath("model.yaml") + + for name, metadata in load_metadata_file(metadata_file, load_yaml): + cache[name] = metadata try: metadata = cache["checkpoint@"] - except KeyError as ex: + except KeyError: raise AssetMetadataError( - "The checkpoint model metadata has an invalid format." - ) from ex + "The checkpoint metadata does not have a 'checkpoint@' entry." + ) from None - try: - num_shards = int(metadata["num_shards"]) - except KeyError: - num_shards = 1 - except ValueError as ex: + num_shards = metadata.get("num_shards", 1) + + if not isinstance(num_shards, int) or num_shards < 1: raise AssetMetadataError( - "The checkpoint model metadata has an invalid format." - ) from ex + "The 'num_shards' value in the checkpoint metadata is not a positive integer." + ) if num_shards == 1: filename = "model.pt" @@ -78,7 +79,7 @@ def add_checkpoint_metadata(name: str, path: Path) -> None: scores = [] try: - for step_dir in self._checkpoint_dir.glob("step_*"): + for step_dir in checkpoint_dir.glob("step_*"): if not step_dir.is_dir(): continue @@ -100,38 +101,45 @@ def add_checkpoint_metadata(name: str, path: Path) -> None: with score_file.open() as fp: line = fp.readline() except OSError as ex: - raise RuntimeError( - f"The score of training step {step_nr} cannot be loaded. See nested exception for details." + raise AssetMetadataError( + f"The score of the training step {step_nr} cannot be loaded from the '{score_file}' file. See the nested exception for details." ) from ex try: score = float(line) - except ValueError as ex: - raise RuntimeError( - f"The score of training step {step_nr} cannot be loaded. See nested exception for details." - ) from ex + except ValueError: + raise AssetMetadataError( + f"The score of the training step {step_nr} cannot be parsed as a floating-point number." + ) from None scores.append((score, step_nr)) except OSError as ex: - raise RuntimeError( - "The base checkpoint directory cannot be traversed. See nested exception for details." + raise AssetMetadataError( + f"The base '{checkpoint_dir}' checkpoint directory cannot be traversed. See the nested exception for details." ) from ex if max_step_nr >= 0: - last_model_file = self._checkpoint_dir.joinpath( - f"step_{max_step_nr}/{filename}" - ) + last_model_file = checkpoint_dir.joinpath(f"step_{max_step_nr}/{filename}") add_checkpoint_metadata("last_checkpoint@", last_model_file) - if self._lower_score_better: - scores.sort(key=lambda e: (-e[0], e[1])) - else: - scores.sort() + scores.sort() - for rank, (_, step_nr) in enumerate(reversed(scores)): - model_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/{filename}") + last_idx = len(scores) - 1 - add_checkpoint_metadata(f"checkpoint_best_{rank}@", model_file) + for i, (_, step_nr) in enumerate(scores): + model_file = checkpoint_dir.joinpath(f"step_{step_nr}/{filename}") - return cache + add_checkpoint_metadata(f"checkpoint_lowest_{i}@", model_file) + add_checkpoint_metadata(f"checkpoint_highest_{last_idx - i}@", model_file) + + def _load_tokenizer(self, cache: dict[str, dict[str, object]]) -> None: + checkpoint_dir = self._checkpoint_dir.expanduser().resolve() + + metadata_file = checkpoint_dir.joinpath("tokenizer.yaml") + if metadata_file.exists(): + for name, metadata in load_metadata_file(metadata_file, load_yaml): + cache[name] = metadata + + +CheckpointModelMetadataProvider: TypeAlias = FileCheckpointMetadataProvider # compat diff --git a/src/fairseq2/config_registry.py b/src/fairseq2/config_registry.py index 3c874b53b..54d9d03ea 100644 --- a/src/fairseq2/config_registry.py +++ b/src/fairseq2/config_registry.py @@ -6,67 +6,103 @@ from __future__ import annotations -from typing import AbstractSet, Callable, Dict, Generic, Protocol, TypeVar, final +from abc import ABC, abstractmethod +from collections.abc import Callable, Set +from functools import cached_property +from typing import Any, Generic, Protocol, TypeVar, final, get_args -from fairseq2.typing import DataClass +from typing_extensions import override -ConfigT = TypeVar("ConfigT", bound=DataClass) +from fairseq2.error import AlreadyExistsError -ConfigT_co = TypeVar("ConfigT_co", bound=DataClass, covariant=True) +ConfigT = TypeVar("ConfigT") +ConfigT_co = TypeVar("ConfigT_co", covariant=True) -class ConfigFactory(Protocol[ConfigT_co]): - """Constructs instances of ``ConfigT``.""" +class ConfigProvider(ABC, Generic[ConfigT_co]): + """Provides configurations of type ``ConfigT``.""" + + @abstractmethod + def get(self, name: str) -> ConfigT_co: + """Return the configuration of ``name``.""" + + @abstractmethod + def names(self) -> Set[str]: + """Return the names of all configurations.""" + + @property + @abstractmethod + def config_kls(self) -> type[ConfigT_co]: + """The type of the configuration.""" + + +class ConfigSupplier(Protocol[ConfigT_co]): def __call__(self) -> ConfigT_co: ... @final -class ConfigRegistry(Generic[ConfigT]): +class ConfigRegistry(ConfigProvider[ConfigT]): """Holds configurations of type ``ConfigT``.""" - _configs: Dict[str, ConfigFactory[ConfigT]] + _configs: dict[str, ConfigSupplier[ConfigT]] def __init__(self) -> None: self._configs = {} + @override def get(self, name: str) -> ConfigT: - """Return the configuration of ``name``.""" try: return self._configs[name]() except KeyError: - raise ValueError( - f"`name` must be a registered configuration name, but is '{name}' instead." - ) from None + raise ConfigNotFoundError(name) from None - def register(self, name: str, config_factory: ConfigFactory[ConfigT]) -> None: + def register(self, name: str, supplier: ConfigSupplier[ConfigT]) -> None: """Register a new configuration. - :param name: - The name of the configuration. - :param config_factory: - The factory to construct configurations. + :param name: The name of the configuration. + :param config_supplier: The configuration supplier. """ if name in self._configs: - raise ValueError( - f"`name` must be a unique configuration name, but '{name}' has already a registered configuration factory." + raise AlreadyExistsError( + f"The registry has already a configuration named '{name}'." ) - self._configs[name] = config_factory + self._configs[name] = supplier def decorator( self, name: str - ) -> Callable[[ConfigFactory[ConfigT]], ConfigFactory[ConfigT]]: - """Register ``name`` with the decorated configuration factory.""" + ) -> Callable[[ConfigSupplier[ConfigT]], ConfigSupplier[ConfigT]]: + """Register ``name`` with the decorated configuration supplier.""" - def register(config_factory: ConfigFactory[ConfigT]) -> ConfigFactory[ConfigT]: - self.register(name, config_factory) + def register(supplier: ConfigSupplier[ConfigT]) -> ConfigSupplier[ConfigT]: + self.register(name, supplier) - return config_factory + return supplier return register - def names(self) -> AbstractSet[str]: - """Return the names of all configurations.""" + @override + def names(self) -> Set[str]: return self._configs.keys() + + @override + @property + def config_kls(self) -> type[ConfigT]: + return self._config_kls # type: ignore[no-any-return] + + @cached_property + def _config_kls(self) -> Any: + kls_args = get_args(self.__orig_class__) # type: ignore[attr-defined] + + return kls_args[0] + + +class ConfigNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a registered configuration name.") + + self.name = name diff --git a/src/fairseq2/console.py b/src/fairseq2/console.py deleted file mode 100644 index 6fbfca9e1..000000000 --- a/src/fairseq2/console.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from typing import Optional - -from rich import get_console as get_rich_console -from rich.console import Console - -_console: Optional[Console] = None - - -def get_console() -> Console: - """Return the ``stdout`` Rich console.""" - global _console - - if _console is None: - _console = get_rich_console() - - return _console - - -def set_console(console: Console) -> None: - """Set the ``stdout`` Rich console.""" - global _console - - _console = console - - -_error_console: Optional[Console] = None - - -def get_error_console() -> Console: - """Return the ``stderr`` Rich console.""" - global _error_console - - if _error_console is None: - _error_console = Console(stderr=True, highlight=False) - - return _error_console - - -def set_error_console(console: Console) -> None: - """Get the ``stderr`` Rich console.""" - global _error_console - - _error_console = console diff --git a/src/fairseq2/context.py b/src/fairseq2/context.py new file mode 100644 index 000000000..4872f81ef --- /dev/null +++ b/src/fairseq2/context.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Hashable, Iterable +from typing import Any, Generic, Mapping, TypeVar, final + +from typing_extensions import override + +from fairseq2.assets import AssetDownloadManager, StandardAssetStore +from fairseq2.error import AlreadyExistsError + +T = TypeVar("T") + + +@final +class RuntimeContext: + _asset_store: StandardAssetStore + _asset_download_manager: AssetDownloadManager + _registries: Mapping[type, Registry[Any]] + + def __init__( + self, + asset_store: StandardAssetStore, + asset_download_manager: AssetDownloadManager, + ) -> None: + self._asset_store = asset_store + self._asset_download_manager = asset_download_manager + + self._registries = defaultdict(Registry) + + @property + def asset_store(self) -> StandardAssetStore: + return self._asset_store + + @property + def asset_download_manager(self) -> AssetDownloadManager: + return self._asset_download_manager + + def get_registry(self, kls: type[T]) -> Registry[T]: + return self._registries[kls] + + +T_co = TypeVar("T_co", covariant=True) + + +class Provider(ABC, Generic[T_co]): + @abstractmethod + def get(self, key: Hashable) -> T_co: + ... + + @abstractmethod + def get_all(self) -> Iterable[tuple[Hashable, T_co]]: + ... + + +@final +class Registry(Provider[T]): + _entries: dict[Hashable, T] + + def __init__(self) -> None: + self._entries = {} + + @override + def get(self, key: Hashable) -> T: + try: + return self._entries[key] + except KeyError: + raise LookupError(f"The registry does not contain a '{key}' key.") from None + + @override + def get_all(self) -> Iterable[tuple[Hashable, T]]: + return self._entries.items() + + def register(self, key: Hashable, value: T) -> None: + if key in self._entries: + raise AlreadyExistsError(f"The registry already contains a '{key}' key.") + + self._entries[key] = value + + +_default_context: RuntimeContext | None = None + + +def set_runtime_context(context: RuntimeContext) -> None: + global _default_context + + _default_context = context + + +def get_runtime_context() -> RuntimeContext: + if _default_context is None: + raise RuntimeError( + "fairseq2 is not initialized. Make sure to call `fairseq2.setup_fairseq2()`." + ) + + return _default_context diff --git a/src/fairseq2/data/audio.py b/src/fairseq2/data/audio.py index 45a241cac..3b54db9e8 100644 --- a/src/fairseq2/data/audio.py +++ b/src/fairseq2/data/audio.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, TypedDict, Union, final +from typing import TYPE_CHECKING, TypedDict, final from fairseq2n import DOC_MODE from torch import Tensor @@ -22,8 +22,8 @@ class AudioDecoder: def __init__( self, keepdim: bool = False, - dtype: Optional[DataType] = None, - device: Optional[Device] = None, + dtype: DataType | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> None: ... @@ -40,8 +40,8 @@ def __init__( channel_last: bool = False, standardize: bool = False, keep_waveform: bool = False, - dtype: Optional[DataType] = None, - device: Optional[Device] = None, + dtype: DataType | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> None: ... @@ -70,7 +70,7 @@ class AudioDecoderOutput(TypedDict): class WaveformToFbankInput(TypedDict): waveform: Tensor - sample_rate: Union[int, float] + sample_rate: int | float class WaveformToFbankOutput(TypedDict): diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py index 6e99043f3..e1e4a0c3c 100644 --- a/src/fairseq2/data/data_pipeline.py +++ b/src/fairseq2/data/data_pipeline.py @@ -6,23 +6,9 @@ from __future__ import annotations +from collections.abc import Callable, Iterator, Mapping, Sequence from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, - Sequence, - Tuple, - TypedDict, - TypeVar, - Union, - final, -) +from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, final from fairseq2n import DOC_MODE from torch import Tensor @@ -68,7 +54,7 @@ def is_broken(self) -> bool: :class:`DataPipelineError`. """ - def state_dict(self, strict: bool = True) -> Dict[str, Any]: + def state_dict(self, strict: bool = True) -> dict[str, Any]: """Return a dictionary containing the state of the data pipeline. The current position of the data pipeline can be restored by passing @@ -97,7 +83,7 @@ def concat(pipelines: Sequence[DataPipeline]) -> DataPipelineBuilder: """ @staticmethod - def constant(example: Any, key: Optional[str] = None) -> DataPipelineBuilder: + def constant(example: Any, key: str | None = None) -> DataPipelineBuilder: """Repeatedly yield ``example``. This pipeline is pseudo-infinite; when used with functions @@ -105,8 +91,7 @@ def constant(example: Any, key: Optional[str] = None) -> DataPipelineBuilder: it will yield examples only as long as other pipelines yield examples. - See :ref:`reference/data:pseudo-infinite and infinite pipelines` - for more details. + See :ref:`basics/data-pipeline/pipeline-types` for more details. :param example: Example to yield infinitely. @@ -118,7 +103,7 @@ def constant(example: Any, key: Optional[str] = None) -> DataPipelineBuilder: @staticmethod def count( - start: int = 0, step: int = 1, key: Optional[str] = None + start: int = 0, step: int = 1, key: str | None = None ) -> DataPipelineBuilder: """Count from ``start`` in steps of size ``step``. @@ -127,8 +112,7 @@ def count( it will yield examples only as long as other pipelines yield examples. - See :ref:`reference/data:pseudo-infinite and infinite pipelines` - for more details. + See :ref:`basics/data-pipeline/pipeline-types` for more details. :param start: Number to start counting from. @@ -148,6 +132,8 @@ def round_robin( ) -> DataPipelineBuilder: """Extract examples from ``pipelines`` in round robin. + See :ref:`basics/data-pipeline/combining-pipelines` for more details. + :param pipelines: The data pipelines to round robin. :param stop_at_shortest: @@ -161,14 +147,16 @@ def round_robin( @staticmethod def sample( pipelines: Sequence[DataPipeline], - weights: Optional[Sequence[float]] = None, - seed: Optional[int] = None, + weights: Sequence[float] | None = None, + seed: int | None = None, allow_repeats: bool = True, ) -> DataPipelineBuilder: """Extract examples from ``pipelines`` by sampling based on ``weights``. Circles around pipelines until all have reached their end at least once. + See :ref:`basics/data-pipeline/combining-pipelines` for more details. + :param data_pipelines: The data pipelines to sample from. :param weights: @@ -182,13 +170,15 @@ def sample( @staticmethod def zip( pipelines: Sequence[DataPipeline], - names: Optional[Sequence[str]] = None, + names: Sequence[str] | None = None, zip_to_shortest: bool = False, flatten: bool = False, disable_parallelism: bool = False, ) -> DataPipelineBuilder: """Zip together examples read from ``pipelines``. + See :ref:`basics/data-pipeline/combining-pipelines` for more details. + :param pipelines: The data pipelines to zip. :param names: @@ -211,6 +201,8 @@ class DataPipelineBuilder: def bucket(self, bucket_size: int, drop_remainder: bool = False) -> Self: """Combine a number of consecutive examples into a single example. + See :ref:`basics/data-pipeline/combining-pipelines` for more details. + :param bucket_size: The number of examples to combine. :param drop_remainder: @@ -220,8 +212,8 @@ def bucket(self, bucket_size: int, drop_remainder: bool = False) -> Self: def bucket_by_length( self, - bucket_sizes: Sequence[Tuple[int, int]], - selector: Optional[str] = None, + bucket_sizes: Sequence[tuple[int, int]], + selector: str | None = None, min_data_len: int = 1, skip_below_min_examples: bool = False, skip_above_max_examples: bool = False, @@ -231,9 +223,9 @@ def bucket_by_length( def collate( self, - pad_value: Optional[int] = None, + pad_value: int | None = None, pad_to_multiple: int = 1, - overrides: Optional[Sequence[CollateOptionsOverride]] = None, + overrides: Sequence[CollateOptionsOverride] | None = None, ) -> Self: """Concatenate a list of inputs into a single inputs. @@ -245,8 +237,12 @@ def dynamic_bucket( self, threshold: float, cost_fn: Callable[[Any], float], - min_num_examples: Optional[int] = None, - max_num_examples: Optional[int] = None, + bucket_creation_fn: ( + Callable[[Sequence[Any]], tuple[Sequence[Sequence[Any]], Sequence[Any]]] + | None + ) = None, + min_num_examples: int | None = None, + max_num_examples: int | None = None, drop_remainder: bool = False, ) -> Self: """Combine a number of consecutive examples into a single example @@ -260,6 +256,13 @@ def dynamic_bucket( Threshold for cumulative cost to trigger bucketing. :param cost_fn: Cost function that outputs cost for a particular example. + :param bucket_creation_fn: + Function for customizing bucket creation. Called with the bucket of + examples that caused the cost threshold to be exceeded. + Expected to return a tuple of ``(new_buckets, remainder)``, where + the internal buffer is set to ``remainder`` and ``new_buckets`` is + a list of buckets to be yielded. If ``None``, defaults to the + identity function. :param min_num_examples: Minimum number of examples per bucket. :param max_num_examples: @@ -280,8 +283,8 @@ def filter(self, predicate: Callable[[Any], Any]) -> Self: def map( self, - fn: Union[Callable[[Any], Any], Sequence[Callable[[Any], Any]]], - selector: Optional[str] = None, + fn: Callable[[Any], Any] | Sequence[Callable[[Any], Any]], + selector: str | None = None, num_parallel_calls: int = 1, ) -> Self: """Apply ``fn`` to each example. @@ -306,7 +309,7 @@ def map( :param selector: The column to apply the function to. Several columns can be specified by separating them with a ",". - See :ref:`reference/data:column syntax` for more details. + See :ref:`basics/data-pipeline/column-selection` for more details. :param num_parallel_calls: The number of examples to process in parallel. """ @@ -320,7 +323,7 @@ def prefetch(self, num_examples: int) -> Self: """ def repeat( - self, num_repeats: Optional[int] = None, reset_rng: bool = False + self, num_repeats: int | None = None, reset_rng: bool = False ) -> Self: """Repeats the sequence of pipeline examples ``num_repeats`` times. @@ -342,7 +345,7 @@ def shard( The number of shards. """ - def shuffle(self, shuffle_window: int, seed: Optional[int] = None) -> Self: + def shuffle(self, shuffle_window: int, seed: int | None = None) -> Self: """Shuffle examples using a fixed sized buffer. :param shuffle_window: @@ -370,13 +373,13 @@ def yield_from(self, fn: Callable[[Any], DataPipeline]) -> Self: def and_return(self, max_num_warnings: int = 0) -> DataPipeline: """Return a new :class:`DataPipeline` instance.""" - class DataPipelineError(RuntimeError): + class DataPipelineError(Exception): """Raised when an error occurs while reading from a data pipeline.""" def get_last_failed_example() -> Any: ... - def list_files(path: Path, pattern: Optional[str] = None) -> DataPipelineBuilder: + def list_files(path: Path, pattern: str | None = None) -> DataPipelineBuilder: """List recursively all files under ``path`` that matches ``pattern``. :param path: @@ -432,13 +435,13 @@ class CollateOptionsOverride: :param selector: The columns this overrides applies to. - See :ref:`reference/data:column syntax` for details on how to specify columns. + See :ref:`basics/data-pipeline/column-selection` for details on how to specify columns. """ def __init__( self, selector: str, - pad_value: Optional[int] = None, + pad_value: int | None = None, pad_to_multiple: int = 1, ) -> None: ... @@ -448,7 +451,7 @@ def selector(self) -> str: ... @property - def pad_value(self) -> Optional[int]: + def pad_value(self) -> int | None: ... @property @@ -491,9 +494,9 @@ class Collater: def __init__( self, - pad_value: Optional[int] = None, + pad_value: int | None = None, pad_to_multiple: int = 1, - overrides: Optional[Sequence[CollateOptionsOverride]] = None, + overrides: Sequence[CollateOptionsOverride] | None = None, ) -> None: ... @@ -520,8 +523,8 @@ class FileMapper: def __init__( self, - root_dir: Optional[Path] = None, - cached_fd_count: Optional[int] = None, + root_dir: Path | None = None, + cached_fd_count: int | None = None, ) -> None: ... @@ -538,10 +541,10 @@ def __call__(self, pathname: str) -> FileMapperOutput: """ ... - class ByteStreamError(RuntimeError): + class ByteStreamError(Exception): """Raised when a dataset file can't be read.""" - class RecordError(RuntimeError): + class RecordError(Exception): """Raised when a corrupt record is encountered while reading a dataset.""" else: @@ -609,7 +612,7 @@ def create_bucket_sizes( max_seq_len: int, min_seq_len: int = 1, num_seqs_multiple_of: int = 1, -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: """Create optimal bucket sizes for :meth:`DataPipeline.bucket_by_length`. :param max_num_elements: diff --git a/src/fairseq2/data/image.py b/src/fairseq2/data/image.py index 13fcef071..20588b8a0 100644 --- a/src/fairseq2/data/image.py +++ b/src/fairseq2/data/image.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, TypedDict, final +from typing import TYPE_CHECKING, TypedDict, final from fairseq2n import DOC_MODE from torch import Tensor @@ -20,7 +20,7 @@ class ImageDecoder: def __init__( self, - device: Optional[Device] = None, + device: Device | None = None, pin_memory: bool = False, ) -> None: ... diff --git a/src/fairseq2/data/memory.py b/src/fairseq2/data/memory.py index 15dfa44e9..f4b2f3625 100644 --- a/src/fairseq2/data/memory.py +++ b/src/fairseq2/data/memory.py @@ -7,12 +7,11 @@ from __future__ import annotations from array import array -from typing import TYPE_CHECKING, Optional, Union, final, overload +from typing import TYPE_CHECKING, TypeAlias, final, overload from fairseq2n import DOC_MODE -from typing_extensions import TypeAlias -Buffer: TypeAlias = Union[bytes, bytearray, memoryview, array] # type: ignore[type-arg] +Buffer: TypeAlias = bytes | bytearray | memoryview | array # type: ignore[type-arg] if TYPE_CHECKING or DOC_MODE: @@ -28,7 +27,7 @@ def __init__(self) -> None: def __init__(self, buffer: Buffer, copy: bool = False) -> None: ... - def __init__(self, buffer: Optional[Buffer] = None, copy: bool = False) -> None: + def __init__(self, buffer: Buffer | None = None, copy: bool = False) -> None: """ :param buffer: An object that supports the Python buffer protocol. diff --git a/src/fairseq2/data/parquet_v0/dataloader.py b/src/fairseq2/data/parquet_v0/dataloader.py index 6eac2a0f4..26b7de229 100644 --- a/src/fairseq2/data/parquet_v0/dataloader.py +++ b/src/fairseq2/data/parquet_v0/dataloader.py @@ -6,10 +6,11 @@ from __future__ import annotations +from collections.abc import Generator from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Generator, List, Optional, Union +from typing import Any import pyarrow as pa import pyarrow.parquet as pq @@ -41,24 +42,24 @@ class ParquetBasicDataloaderConfig: parquet_path: str """The path to parquet dataset file.""" - batch_size: Optional[int] = None + batch_size: int | None = None """The output batch size.""" - order_by_length: Optional[str] = None + order_by_length: str | None = None """The column in the dataset whose length will be used for batch ordering. This results in batches with relatively homogeneous values, typically to support optimal padding.""" - max_tokens: Optional[int] = None + max_tokens: int | None = None """Used with the ``order_by_length`` option to control the total number of padded tokens in each batch. Typically, this option is preferred over ``batch_size`` to reduce the memory footprint. """ - columns: Optional[List[str]] = None + columns: list[str] | None = None """The list of columns to load.""" - filters: Optional[Union[List[Any], pa.dataset.Expression]] = None + filters: list[Any] | pa.dataset.Expression | None = None """See https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression Some examples : @@ -119,7 +120,7 @@ class ParquetBasicDataloaderConfig: Since we rely on the external parallelism, this param is tuned off by default.""" - filesystem: Optional[pa.fs.FileSystem] = None + filesystem: pa.fs.FileSystem | None = None """The filesystem to read the Parquet files from. S3 example: >>> import s3fs >>> filesystem = s3fs.core.S3FileSystem(...) @@ -163,9 +164,11 @@ def inner_iterator(wrap_table: _TableWrapper) -> DataPipeline: columns=config.columns, split_to_row_groups=config.split_to_row_groups, filesystem=config.filesystem, - shuffle_window=2 * config.nb_prefetch * config.nb_parallel_fragments - if config.shuffle - else None, + shuffle_window=( + 2 * config.nb_prefetch * config.nb_parallel_fragments + if config.shuffle + else None + ), seed=config.seed, ) .shard(shard_idx=config.rank, num_shards=config.world_size) diff --git a/src/fairseq2/data/text/__init__.py b/src/fairseq2/data/text/__init__.py index 09b30c972..02ea1b5ac 100644 --- a/src/fairseq2/data/text/__init__.py +++ b/src/fairseq2/data/text/__init__.py @@ -9,51 +9,10 @@ from fairseq2.data.text.converters import StrSplitter as StrSplitter from fairseq2.data.text.converters import StrToIntConverter as StrToIntConverter from fairseq2.data.text.converters import StrToTensorConverter as StrToTensorConverter -from fairseq2.data.text.sentencepiece import ( - BasicSentencePieceTokenizer as BasicSentencePieceTokenizer, -) -from fairseq2.data.text.sentencepiece import ( - RawSentencePieceTokenizer as RawSentencePieceTokenizer, -) -from fairseq2.data.text.sentencepiece import ( - SentencePieceDecoder as SentencePieceDecoder, -) -from fairseq2.data.text.sentencepiece import ( - SentencePieceEncoder as SentencePieceEncoder, -) -from fairseq2.data.text.sentencepiece import SentencePieceModel as SentencePieceModel -from fairseq2.data.text.sentencepiece import ( - SentencePieceTokenizer as SentencePieceTokenizer, -) -from fairseq2.data.text.sentencepiece import ( - default_basic_sentencepiece_tokenizer_loader as default_basic_sentencepiece_tokenizer_loader, -) -from fairseq2.data.text.sentencepiece import ( - default_raw_sentencepiece_tokenizer_loader as default_raw_sentencepiece_tokenizer_loader, -) -from fairseq2.data.text.sentencepiece import ( - vocab_info_from_sentencepiece as vocab_info_from_sentencepiece, -) from fairseq2.data.text.text_reader import LineEnding as LineEnding from fairseq2.data.text.text_reader import read_text as read_text -from fairseq2.data.text.text_tokenizer import ( - AbstractTextTokenizer as AbstractTextTokenizer, -) -from fairseq2.data.text.text_tokenizer import ( - AbstractTextTokenizerLoader as AbstractTextTokenizerLoader, -) -from fairseq2.data.text.text_tokenizer import ( - DelegatingTextTokenizerLoader as DelegatingTextTokenizerLoader, -) -from fairseq2.data.text.text_tokenizer import TextTokenDecoder as TextTokenDecoder -from fairseq2.data.text.text_tokenizer import TextTokenEncoder as TextTokenEncoder -from fairseq2.data.text.text_tokenizer import TextTokenizer as TextTokenizer -from fairseq2.data.text.text_tokenizer import TextTokenizerLoader as TextTokenizerLoader -from fairseq2.data.text.text_tokenizer import ( - get_tokenizer_family as get_tokenizer_family, -) -from fairseq2.data.text.text_tokenizer import is_tokenizer_card as is_tokenizer_card -from fairseq2.data.text.text_tokenizer import load_text_tokenizer as load_text_tokenizer -from fairseq2.data.text.tiktoken import TiktokenDecoder as TiktokenDecoder -from fairseq2.data.text.tiktoken import TiktokenEncoder as TiktokenEncoder -from fairseq2.data.text.tiktoken import TiktokenTokenizer as TiktokenTokenizer +from fairseq2.data.text.tokenizers import AbstractTextTokenizer as AbstractTextTokenizer +from fairseq2.data.text.tokenizers import TextTokenDecoder as TextTokenDecoder +from fairseq2.data.text.tokenizers import TextTokenEncoder as TextTokenEncoder +from fairseq2.data.text.tokenizers import TextTokenizer as TextTokenizer +from fairseq2.data.text.tokenizers import load_text_tokenizer as load_text_tokenizer diff --git a/src/fairseq2/data/text/converters.py b/src/fairseq2/data/text/converters.py index 90a1cd15c..a7b4dd174 100644 --- a/src/fairseq2/data/text/converters.py +++ b/src/fairseq2/data/text/converters.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union, final +from collections.abc import Sequence +from typing import TYPE_CHECKING, final from fairseq2n import DOC_MODE from torch import Tensor @@ -27,14 +28,22 @@ class StrSplitter: Will create dictionaries object with one entry per column :param indices: - The indices of the column to keep. + A list of indices of the column to keep, or a single index. + If a single index is provided and ``exclude`` is ``False``, + the output is a string. + + :param exclude: + If ``True``, the indices will be excluded from the output, + instead of kept. Default to ``False``. Example usage:: # read all columns: ["Go.", "Va !", "CC-BY 2.0 (France)"] dataloader = read_text("tatoeba.tsv").map(StrSplitter()).and_return() - # keep only the second column and convert to string: "Va !" - dataloader = read_text("tatoeba.tsv").map(StrSplitter(indices=[1])).map(lambda x: x[0]).and_return() + # keep first and second columns, yielding the list: ["Go.", "Va !"] + dataloader = read_text("tatoeba.tsv").map(StrSplitter(indices=[0, 1])).and_return() + # keep only the second column, directly yielding a string: "Va !" + dataloader = read_text("tatoeba.tsv").map(StrSplitter(indices=1)).and_return() # keep only the first and second column and convert to dict: {"en": "Go.", "fr": "Va !"} dataloader = read_text("tatoeba.tsv").map(StrSplitter(names=["en", "fr"], indices=[0, 1])).and_return() @@ -43,13 +52,13 @@ class StrSplitter: def __init__( self, sep: str = "\t", - names: Optional[Sequence[str]] = None, - indices: Optional[Sequence[int]] = None, + names: Sequence[str] | None = None, + indices: int | Sequence[int] | None = None, exclude: bool = False, ) -> None: ... - def __call__(self, s: str) -> Union[List[str], Dict[str, str]]: + def __call__(self, s: str) -> str | list[str] | dict[str, str]: ... @final @@ -66,8 +75,8 @@ def __call__(self, s: str) -> int: class StrToTensorConverter: def __init__( self, - size: Optional[Sequence[int]] = None, - dtype: Optional[DataType] = None, + size: Sequence[int] | None = None, + dtype: DataType | None = None, ) -> None: ... diff --git a/src/fairseq2/data/text/text_reader.py b/src/fairseq2/data/text/text_reader.py index 2df08338c..6c5c69cc5 100644 --- a/src/fairseq2/data/text/text_reader.py +++ b/src/fairseq2/data/text/text_reader.py @@ -8,7 +8,7 @@ from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from fairseq2n import DOC_MODE @@ -23,14 +23,14 @@ class LineEnding(Enum): def read_text( path: Path, - key: Optional[str] = None, - encoding: Optional[str] = None, + key: str | None = None, + encoding: str | None = None, line_ending: LineEnding = LineEnding.INFER, ltrim: bool = False, rtrim: bool = False, skip_empty: bool = False, memory_map: bool = False, - block_size: Optional[int] = None, + block_size: int | None = None, ) -> DataPipelineBuilder: """Open a text file and return a data pipeline reading lines one by one.""" ... diff --git a/src/fairseq2/data/text/text_tokenizer.py b/src/fairseq2/data/text/text_tokenizer.py deleted file mode 100644 index 668d8a726..000000000 --- a/src/fairseq2/data/text/text_tokenizer.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Dict, List, Optional, Protocol, Sequence, TypeVar, Union, final - -from torch import Tensor - -from fairseq2.assets import ( - AssetCard, - AssetCardError, - AssetDownloadManager, - AssetError, - AssetStore, - default_asset_store, - default_download_manager, -) -from fairseq2.data.vocabulary_info import VocabularyInfo -from fairseq2.typing import Device, override - - -class TextTokenizer(ABC): - """Represents a tokenizer to encode and decode text.""" - - @abstractmethod - def create_encoder( - self, - *, - task: Optional[str] = None, - lang: Optional[str] = None, - mode: Optional[str] = None, - device: Optional[Device] = None, - pin_memory: bool = False, - ) -> TextTokenEncoder: - """Create a token encoder. - - The valid arguments for the ``task``, ``lang``, and ``mode`` parameters - are implementation specific. Refer to concrete ``TextTokenizer`` - subclasses for more information. - - :param task: - The task for which to generate token indices. Typically, ``task`` is - used to distinguish between different tasks such as 'translation' or - 'transcription'. - :param lang: - The language of generated token indices. Typically, multilingual - translation tasks use ``lang`` to distinguish between different - languages such as 'en-US' or 'de-DE'. - :param mode: - The mode in which to generate token indices. Typically, translation - tasks use ``mode`` to distinguish between different modes such as - 'source' or 'target'. - :param device: - The device on which to construct tensors. - :param pin_memory: - If ``True``, uses pinned memory while constructing tensors. - """ - - @abstractmethod - def create_raw_encoder( - self, *, device: Optional[Device] = None, pin_memory: bool = False - ) -> TextTokenEncoder: - """Create a raw token encoder with no control symbols. - - :param device: - The device on which to construct tensors. - :param pin_memory: - If ``True``, uses pinned memory while constructing tensors. - """ - - @abstractmethod - def create_decoder(self) -> TextTokenDecoder: - """Create a token decoder.""" - - @property - @abstractmethod - def vocab_info(self) -> VocabularyInfo: - """The vocabulary information associated with the tokenizer.""" - - -class AbstractTextTokenizer(TextTokenizer): - """Provides a skeletal implementation of :class:`TextTokenizer`.""" - - _vocab_info: VocabularyInfo - - def __init__(self, vocab_info: VocabularyInfo) -> None: - """ - :param vocab_info: - The vocabulary information associated with the tokenizer. - """ - self._vocab_info = vocab_info - - @final - @property - @override - def vocab_info(self) -> VocabularyInfo: - """The vocabulary information associated with the tokenizer.""" - return self._vocab_info - - -class TextTokenEncoder(ABC): - """Encodes text into tokens or token indices.""" - - @abstractmethod - def __call__(self, text: str) -> Tensor: - """ - :param text: - The text to encode. - """ - - @abstractmethod - def encode_as_tokens(self, text: str) -> List[str]: - """ - :param text: - The text to encode. - """ - - @property - @abstractmethod - def prefix_indices(self) -> Optional[Tensor]: - """Get the indices of the prefix tokens. *Shape:* :math:`(S)`, where - :math:`S` is the number of indices.""" - - @property - @abstractmethod - def suffix_indices(self) -> Optional[Tensor]: - """Get the indices of the suffix tokens. *Shape:* :math:`(S)`, where - :math:`S` is the number of indices.""" - - -class TextTokenDecoder(ABC): - """Decodes text from tokens or token indices.""" - - @abstractmethod - def __call__(self, token_indices: Tensor) -> str: - """ - :param token_indices: - The token indices to decode from. - """ - - @abstractmethod - def decode_from_tokens(self, tokens: Sequence[str]) -> str: - """ - :param tokens: - The tokens to decode from. - """ - - -TextTokenizerT = TypeVar("TextTokenizerT", bound=TextTokenizer) - -TextTokenizerT_co = TypeVar("TextTokenizerT_co", bound=TextTokenizer, covariant=True) - - -class TextTokenizerLoader(Protocol[TextTokenizerT_co]): - """Loads text tokenizers of type ``TextTokenizerT``.""" - - def __call__( - self, - tokenizer_name_or_card: Union[str, AssetCard], - *, - force: bool = False, - progress: bool = True, - ) -> TextTokenizerT_co: - """ - :param tokenizer_name_or_card: - The name or the asset card of the tokenizer to load. - :param force: - If ``True``, downloads the tokenizer even if it is already in cache. - :param progress: - If ``True``, displays a progress bar to stderr. - """ - - -class AbstractTextTokenizerLoader(ABC, TextTokenizerLoader[TextTokenizerT]): - """Provides a skeletal implementation of :class:`TextTokenizerLoader`.""" - - _asset_store: AssetStore - _download_manager: AssetDownloadManager - - def __init__( - self, - *, - asset_store: Optional[AssetStore] = None, - download_manager: Optional[AssetDownloadManager] = None, - ) -> None: - """ - :param asset_store: - The asset store where to check for available tokenizers. If ``None``, - the default asset store will be used. - :param download_manager: - The download manager. If ``None``, the default download manager will - be used. - """ - self._asset_store = asset_store or default_asset_store - self._download_manager = download_manager or default_download_manager - - @final - def __call__( - self, - tokenizer_name_or_card: Union[str, AssetCard], - *, - force: bool = False, - progress: bool = True, - ) -> TextTokenizerT: - if isinstance(tokenizer_name_or_card, AssetCard): - card = tokenizer_name_or_card - else: - card = self._asset_store.retrieve_card(tokenizer_name_or_card) - - tokenizer_ref = card.field("tokenizer_ref").get_as_(str) - if tokenizer_ref is not None: - return self(tokenizer_ref, force=force, progress=progress) - - tokenizer_uri = card.field("tokenizer").as_uri() - - try: - path = self._download_manager.download_tokenizer( - tokenizer_uri, card.name, force=force, progress=progress - ) - except ValueError as ex: - raise AssetCardError( - f"The value of the field 'tokenizer' of the asset card '{card.name}' must be a URI. See nested exception for details." - ) from ex - - try: - return self._load(path, card) - except ValueError as ex: - raise AssetError( - f"The {card.name} tokenizer cannot be loaded. See nested exception for details." - ) from ex - - @abstractmethod - def _load(self, path: Path, card: AssetCard) -> TextTokenizerT: - """ - :param path: - The path to the tokenizer. - :param card: - The asset card of the tokenizer. - """ - - -@final -class DelegatingTextTokenizerLoader(TextTokenizerLoader[TextTokenizerT]): - """Loads text tokenizers of type ``TextTokenizerT`` using registered loaders.""" - - _asset_store: AssetStore - _loaders: Dict[str, TextTokenizerLoader[TextTokenizerT]] - - def __init__(self, *, asset_store: Optional[AssetStore] = None) -> None: - """ - :param asset_store: - The asset store where to check for available tokenizers. If ``None``, - the default asset store will be used. - """ - self._asset_store = asset_store or default_asset_store - - self._loaders = {} - - def __call__( - self, - tokenizer_name_or_card: Union[str, AssetCard], - *, - force: bool = False, - progress: bool = True, - ) -> TextTokenizerT: - if isinstance(tokenizer_name_or_card, AssetCard): - card = tokenizer_name_or_card - else: - card = self._asset_store.retrieve_card(tokenizer_name_or_card) - - ref = card.field("tokenizer_ref").get_as_(str) - if ref is not None: - return self(ref, force=force, progress=progress) - - family = card.field("tokenizer_family").as_(str) - - try: - loader = self._loaders[family] - except KeyError: - raise AssetError( - f"The value of the field 'tokenizer_family' of the asset card '{card.name}' must be a supported tokenizer family, but '{family}' has no registered loader." - ) from None - - return loader(card, force=force, progress=progress) - - def register( - self, family: str, loader: TextTokenizerLoader[TextTokenizerT] - ) -> None: - """Register a tokenizer loader to use with this loader. - - :param family: - The tokenizer family. If the 'tokenizer_family', 'model_family', or - 'dataset_family' field of an asset card matches this value, the - specified ``loader`` will be used. - :param loader: - The tokenizer loader. - """ - if family in self._loaders: - raise ValueError( - f"`family` must be a unique text tokenizer family name, but '{family}' has already a registered loader." - ) - - self._loaders[family] = loader - - def supports(self, tokenizer_name_or_card: Union[str, AssetCard]) -> bool: - """Return ``True`` if the specified tokenizer has a registered loader.""" - if isinstance(tokenizer_name_or_card, AssetCard): - card = tokenizer_name_or_card - else: - card = self._asset_store.retrieve_card(tokenizer_name_or_card) - - ref = card.field("tokenizer_ref").get_as_(str) - if ref is not None: - return self.supports(ref) - - family = card.field("tokenizer_family").as_(str) - - return family in self._loaders - - -load_text_tokenizer = DelegatingTextTokenizerLoader[TextTokenizer]() - - -def is_tokenizer_card(card: AssetCard) -> bool: - """Return ``True`` if ``card`` specifies a tokenizer.""" - return card.field("tokenizer_family").exists() - - -def get_tokenizer_family(card: AssetCard) -> str: - """Return the tokenizer family name contained in ``card``.""" - return card.field("tokenizer_family").as_(str) # type: ignore[no-any-return] diff --git a/src/fairseq2/data/text/tokenizers/__init__.py b/src/fairseq2/data/text/tokenizers/__init__.py new file mode 100644 index 000000000..eb2f218a3 --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.data.text.tokenizers.handler import ( + StandardTextTokenizerHandler as StandardTextTokenizerHandler, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerHandler as TextTokenizerHandler, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerLoader as TextTokenizerLoader, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerNotFoundError as TextTokenizerNotFoundError, +) +from fairseq2.data.text.tokenizers.handler import ( + get_text_tokenizer_family as get_text_tokenizer_family, +) +from fairseq2.data.text.tokenizers.ref import ( + resolve_text_tokenizer_reference as resolve_text_tokenizer_reference, +) +from fairseq2.data.text.tokenizers.static import ( + load_text_tokenizer as load_text_tokenizer, +) +from fairseq2.data.text.tokenizers.tokenizer import ( + AbstractTextTokenizer as AbstractTextTokenizer, +) +from fairseq2.data.text.tokenizers.tokenizer import TextTokenDecoder as TextTokenDecoder +from fairseq2.data.text.tokenizers.tokenizer import TextTokenEncoder as TextTokenEncoder +from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer as TextTokenizer diff --git a/src/fairseq2/data/text/tokenizers/char_tokenizer.py b/src/fairseq2/data/text/tokenizers/char_tokenizer.py new file mode 100644 index 000000000..1986ce453 --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/char_tokenizer.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Final + +from fairseq2.data.text.tokenizers.sentencepiece import load_raw_sentencepiece + +CHAR_TOKENIZER_FAMILY: Final = "char_tokenizer" + +load_char_tokenizer = load_raw_sentencepiece diff --git a/src/fairseq2/data/text/tokenizers/handler.py b/src/fairseq2/data/text/tokenizers/handler.py new file mode 100644 index 000000000..d864f238d --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/handler.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Protocol, final + +from typing_extensions import override + +from fairseq2.assets import AssetCard, AssetDownloadManager +from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer + + +class TextTokenizerHandler(ABC): + @abstractmethod + def load(self, card: AssetCard, *, force: bool = False) -> TextTokenizer: + ... + + +class TextTokenizerNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known text tokenizer.") + + self.name = name + + +class TextTokenizerLoader(Protocol): + def __call__(self, path: Path, card: AssetCard) -> TextTokenizer: + ... + + +@final +class StandardTextTokenizerHandler(TextTokenizerHandler): + _loader: TextTokenizerLoader + _asset_download_manager: AssetDownloadManager + + def __init__( + self, + *, + loader: TextTokenizerLoader, + asset_download_manager: AssetDownloadManager, + ) -> None: + self._loader = loader + self._asset_download_manager = asset_download_manager + + @override + def load(self, card: AssetCard, *, force: bool = False) -> TextTokenizer: + tokenizer_uri = card.field("tokenizer").as_uri() + + path = self._asset_download_manager.download_tokenizer( + tokenizer_uri, card.name, force=force + ) + + return self._loader(path, card) + + +def get_text_tokenizer_family(card: AssetCard) -> str: + return card.field("tokenizer_family").as_(str) # type: ignore[no-any-return] diff --git a/src/fairseq2/data/text/tokenizers/llama.py b/src/fairseq2/data/text/tokenizers/llama.py new file mode 100644 index 000000000..39475fcf3 --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/llama.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from pathlib import Path +from typing import Final, final + +from typing_extensions import override + +from fairseq2.assets import AssetCard, AssetCardError +from fairseq2.data.text.tokenizers.sentencepiece import BasicSentencePieceTokenizer +from fairseq2.data.text.tokenizers.tiktoken import TiktokenEncoder, TiktokenTokenizer +from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer +from fairseq2.typing import Device + + +@final +class LLaMA3Tokenizer(TiktokenTokenizer): + """Represents a LLaMA 3 tokenizer.""" + + _SPLIT_REGEX: Final = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # fmt: skip + + _eos_token: str + + def __init__(self, path: Path, instruct: bool = False) -> None: + """ + :param path: + The path to the tiktoken BPE file. + :param instruct: + If ``True``, uses EOT (end-of-turn) token in-place of EOS token. + """ + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end-of-message + "<|eot_id|>", # end-of-turn + "<|python_tag|>", + ] + + num_reserved_special_tokens = 256 + + for i in range(num_reserved_special_tokens - len(special_tokens)): + special_tokens.append(f"<|reserved_special_token_{2 + i}|>") + + self._eos_token = "<|eot_id|>" if instruct else "<|end_of_text|>" + + super().__init__( + path, + split_regex=self._SPLIT_REGEX, + unk_token=None, + bos_token="<|begin_of_text|>", + eos_token=self._eos_token, + pad_token="<|finetune_right_pad_id|>", + boh_token="<|start_header_id|>", + eoh_token="<|end_header_id|>", + special_tokens=special_tokens, + ) + + @override + def create_encoder( + self, + *, + task: str | None = None, + lang: str | None = None, + mode: str | None = None, + device: Device | None = None, + pin_memory: bool = False, + ) -> TiktokenEncoder: + if task is not None: + raise ValueError(f"`task` must be `None`, but is '{task}' instead.") + + if lang is not None: + raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.") + + match mode: + case None | "default": + prefix_tokens = ["<|begin_of_text|>"] + suffix_tokens = [self._eos_token] + case "prompt": + prefix_tokens = ["<|begin_of_text|>"] + # In prompt mode, we expect the generator to finish the sequence. + suffix_tokens = [] + case "prompt_response": + prefix_tokens = [] + suffix_tokens = [self._eos_token] + case "as_is": + prefix_tokens = [] + suffix_tokens = [] + case _: + raise ValueError( + f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response, as_is" + ) + + return TiktokenEncoder( + self._encoding, + prefix_tokens=prefix_tokens, + suffix_tokens=suffix_tokens, + device=device, + pin_memory=pin_memory, + ) + + +LLAMA_TOKENIZER_FAMILY: Final = "llama" + + +def load_llama_tokenizer(path: Path, card: AssetCard) -> TextTokenizer: + use_v2 = card.field("use_v2_tokenizer").get_as_(bool, False) + if use_v2: + field = card.field("model_config").field("vocab_info").field("eos_idx") + + eos_idx = field.get_as_(int) + + eot_idx = 128_009 # end-of-turn + + try: + return LLaMA3Tokenizer(path, instruct=eos_idx == eot_idx) + except ValueError as ex: + raise AssetCardError( + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip + ) from ex + else: + try: + return BasicSentencePieceTokenizer(path) + except ValueError as ex: + raise AssetCardError( + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip + ) from ex diff --git a/src/fairseq2/data/text/tokenizers/mistral.py b/src/fairseq2/data/text/tokenizers/mistral.py new file mode 100644 index 000000000..299fd163d --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/mistral.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Final + +from fairseq2.data.text.tokenizers.sentencepiece import load_basic_sentencepiece + +MISTRAL_TOKENIZER_FAMILY: Final = "mistal" + +load_mistral_tokenizer = load_basic_sentencepiece diff --git a/src/fairseq2/models/nllb/tokenizer.py b/src/fairseq2/data/text/tokenizers/nllb.py similarity index 56% rename from src/fairseq2/models/nllb/tokenizer.py rename to src/fairseq2/data/text/tokenizers/nllb.py index ca8d15a10..0d3f6abee 100644 --- a/src/fairseq2/models/nllb/tokenizer.py +++ b/src/fairseq2/data/text/tokenizers/nllb.py @@ -6,18 +6,25 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path -from typing import Optional, Sequence, Set, final +from typing import Final, final -from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizer -from fairseq2.typing import Device, override +from typing_extensions import override + +from fairseq2.assets import AssetCard, AssetCardError +from fairseq2.data.text.tokenizers.sentencepiece import ( + SentencePieceEncoder, + SentencePieceTokenizer, +) +from fairseq2.typing import Device @final class NllbTokenizer(SentencePieceTokenizer): """Represents an NLLB tokenizer.""" - _langs: Set[str] + _langs: set[str] _default_lang: str def __init__(self, path: Path, langs: Sequence[str], default_lang: str) -> None: @@ -50,13 +57,13 @@ def __init__(self, path: Path, langs: Sequence[str], default_lang: str) -> None: def create_encoder( self, *, - task: Optional[str] = None, - lang: Optional[str] = None, - mode: Optional[str] = None, - device: Optional[Device] = None, + task: str | None = None, + lang: str | None = None, + mode: str | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> SentencePieceEncoder: - """Create a token encoder. + """Constructs a token encoder. :param task: Must be 'translation'. If ``None``, defaults to 'translation'. @@ -83,29 +90,30 @@ def create_encoder( f"`lang` must be a supported language, but is '{lang}' instead." ) - if mode is None or mode == "source": - # NLLB models expect a language token in place of BOS in source - # sequences. - prefix_tokens = [f"__{lang}__"] - suffix_tokens = [""] - elif mode == "source_mining": - prefix_tokens = [f"__{lang}__", ""] - suffix_tokens = [""] - elif mode == "source_mmt_bt": - prefix_tokens = [f"__{lang}__", ""] - suffix_tokens = [""] - elif mode == "source_smt_bt": - prefix_tokens = [f"__{lang}__", ""] - suffix_tokens = [""] - elif mode == "target": - # Target sequences are expected to start with an EOS, followed by - # the language token. - prefix_tokens = ["", f"__{lang}__"] - suffix_tokens = [""] - else: - raise ValueError( - f"`mode` must be 'source' or 'target', but is '{mode}' instead." - ) + match mode: + case None | "source": + # NLLB models expect a language token in place of BOS in source + # sequences. + prefix_tokens = [f"__{lang}__"] + suffix_tokens = [""] + case "source_mining": + prefix_tokens = [f"__{lang}__", ""] + suffix_tokens = [""] + case "source_mmt_bt": + prefix_tokens = [f"__{lang}__", ""] + suffix_tokens = [""] + case "source_smt_bt": + prefix_tokens = [f"__{lang}__", ""] + suffix_tokens = [""] + case "target": + # Target sequences are expected to start with an EOS, followed by + # the language token. + prefix_tokens = ["", f"__{lang}__"] + suffix_tokens = [""] + case _: + raise ValueError( + f"`mode` must be 'source' or 'target', but is '{mode}' instead." + ) return SentencePieceEncoder( self._model, @@ -114,3 +122,19 @@ def create_encoder( device=device, pin_memory=pin_memory, ) + + +NLLB_TOKENIZER_FAMILY: Final = "nllb" + + +def load_nllb_tokenizer(path: Path, card: AssetCard) -> NllbTokenizer: + langs = card.field("langs").as_(list[str]) + + default_lang = card.field("default_lang").as_(str) + + try: + return NllbTokenizer(path, langs, default_lang) + except ValueError as ex: + raise AssetCardError( + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip + ) from ex diff --git a/src/fairseq2/data/text/tokenizers/ref.py b/src/fairseq2/data/text/tokenizers/ref.py new file mode 100644 index 000000000..b9b9fe856 --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/ref.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.assets import AssetCard, AssetStore + + +def resolve_text_tokenizer_reference( + asset_store: AssetStore, card: AssetCard +) -> AssetCard: + while True: + ref_name = card.field("tokenizer_ref").get_as_(str) + if ref_name is None: + break + + card = asset_store.retrieve_card(ref_name) + + return card diff --git a/src/fairseq2/models/s2t_transformer/tokenizer.py b/src/fairseq2/data/text/tokenizers/s2t_transformer.py similarity index 70% rename from src/fairseq2/models/s2t_transformer/tokenizer.py rename to src/fairseq2/data/text/tokenizers/s2t_transformer.py index 32efe3777..aaf9165f8 100644 --- a/src/fairseq2/models/s2t_transformer/tokenizer.py +++ b/src/fairseq2/data/text/tokenizers/s2t_transformer.py @@ -7,10 +7,16 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, Set, final +from typing import Final, final -from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizer -from fairseq2.typing import Device, override +from typing_extensions import override + +from fairseq2.assets import AssetCard, AssetCardError +from fairseq2.data.text.tokenizers.sentencepiece import ( + SentencePieceEncoder, + SentencePieceTokenizer, +) +from fairseq2.typing import Device @final @@ -18,14 +24,14 @@ class S2TTransformerTokenizer(SentencePieceTokenizer): """Represents an S2T Transformer tokenizer.""" _task: str - _target_langs: Set[str] + _target_langs: set[str] _default_target_lang: str def __init__( self, path: Path, task: str, - target_langs: Set[str], + target_langs: set[str], default_target_lang: str, ) -> None: """ @@ -54,13 +60,13 @@ def __init__( def create_encoder( self, *, - task: Optional[str] = None, - lang: Optional[str] = None, - mode: Optional[str] = None, - device: Optional[Device] = None, + task: str | None = None, + lang: str | None = None, + mode: str | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> SentencePieceEncoder: - """Create a token encoder. + """Constructs a token encoder. :param task: Must match :attr:`task`. If ``None``, defaults to :attr:`task`. @@ -100,3 +106,25 @@ def create_encoder( device=device, pin_memory=pin_memory, ) + + +S2T_TRANSFORMER_TOKENIZER_FAMILY: Final = "s2t_transformer" + + +def load_s2t_transformer_tokenizer( + path: Path, card: AssetCard +) -> S2TTransformerTokenizer: + valid_tasks = {"translation", "transcription"} + + task = card.field("task").as_one_of(valid_tasks) + + target_langs = card.field("target_langs").as_(list[str]) + + try: + return S2TTransformerTokenizer( + path, task, set(target_langs), default_target_lang=target_langs[0] + ) + except ValueError as ex: + raise AssetCardError( + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip + ) from ex diff --git a/src/fairseq2/data/text/sentencepiece.py b/src/fairseq2/data/text/tokenizers/sentencepiece.py similarity index 70% rename from src/fairseq2/data/text/sentencepiece.py rename to src/fairseq2/data/text/tokenizers/sentencepiece.py index ae44dc15a..6857dbd02 100644 --- a/src/fairseq2/data/text/sentencepiece.py +++ b/src/fairseq2/data/text/tokenizers/sentencepiece.py @@ -6,28 +6,29 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Sequence, final +from typing import TYPE_CHECKING, final from fairseq2n import DOC_MODE from torch import Tensor +from typing_extensions import override from fairseq2.assets import AssetCard -from fairseq2.data.text.text_tokenizer import ( +from fairseq2.data.text.tokenizers.tokenizer import ( AbstractTextTokenizer, - AbstractTextTokenizerLoader, TextTokenDecoder, TextTokenEncoder, ) from fairseq2.data.vocabulary_info import VocabularyInfo -from fairseq2.typing import Device, override +from fairseq2.typing import Device if TYPE_CHECKING or DOC_MODE: @final class SentencePieceModel: def __init__( - self, path: Path, control_symbols: Optional[Sequence[str]] = None + self, path: Path, control_symbols: Sequence[str] | None = None ) -> None: ... @@ -38,19 +39,19 @@ def index_to_token(self, idx: int) -> str: ... @property - def unk_idx(self) -> Optional[int]: + def unk_idx(self) -> int | None: ... @property - def bos_idx(self) -> Optional[int]: + def bos_idx(self) -> int | None: ... @property - def eos_idx(self) -> Optional[int]: + def eos_idx(self) -> int | None: ... @property - def pad_idx(self) -> Optional[int]: + def pad_idx(self) -> int | None: ... @property @@ -62,13 +63,13 @@ class SentencePieceEncoder(TextTokenEncoder): def __init__( self, model: SentencePieceModel, - prefix_tokens: Optional[Sequence[str]] = None, - suffix_tokens: Optional[Sequence[str]] = None, + prefix_tokens: Sequence[str] | None = None, + suffix_tokens: Sequence[str] | None = None, reverse: bool = False, enable_sampling: bool = False, nbest_size: int = -1, alpha: float = 0.1, - device: Optional[Device] = None, + device: Device | None = None, pin_memory: bool = False, ) -> None: ... @@ -78,17 +79,17 @@ def __call__(self, text: str) -> Tensor: ... @override - def encode_as_tokens(self, text: str) -> List[str]: + def encode_as_tokens(self, text: str) -> list[str]: ... @property @override - def prefix_indices(self) -> Optional[Tensor]: + def prefix_indices(self) -> Tensor | None: ... @property @override - def suffix_indices(self) -> Optional[Tensor]: + def suffix_indices(self) -> Tensor | None: ... @final @@ -133,7 +134,7 @@ class SentencePieceTokenizer(AbstractTextTokenizer): _model: SentencePieceModel def __init__( - self, path: Path, control_symbols: Optional[Sequence[str]] = None + self, path: Path, control_symbols: Sequence[str] | None = None ) -> None: """ :param path: @@ -149,7 +150,7 @@ def __init__( @override def create_raw_encoder( - self, *, device: Optional[Device] = None, pin_memory: bool = False + self, *, device: Device | None = None, pin_memory: bool = False ) -> SentencePieceEncoder: return SentencePieceEncoder(self._model, device=device, pin_memory=pin_memory) @@ -178,13 +179,13 @@ def __init__(self, path: Path) -> None: def create_encoder( self, *, - task: Optional[str] = None, - lang: Optional[str] = None, - mode: Optional[str] = None, - device: Optional[Device] = None, + task: str | None = None, + lang: str | None = None, + mode: str | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> SentencePieceEncoder: - """Create a token encoder. + """Constructs a token encoder. :param task: Must be ``None``. @@ -204,20 +205,21 @@ def create_encoder( if lang is not None: raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.") - if mode is None or mode == "default": - prefix_tokens = [""] - suffix_tokens = [""] - elif mode == "prompt": - prefix_tokens = [""] - # In prompt mode, we expect the generator to finish the sequence. - suffix_tokens = None - elif mode == "prompt_response": - prefix_tokens = [] - suffix_tokens = [""] - else: - raise ValueError( - f"`mode` must be 'default' or 'prompt', but is '{mode}' instead." - ) + match mode: + case None | "default": + prefix_tokens = [""] + suffix_tokens = [""] + case "prompt": + prefix_tokens = [""] + # In prompt mode, we expect the generator to finish the sequence. + suffix_tokens = [] + case "prompt_response": + prefix_tokens = [] + suffix_tokens = [""] + case _: + raise ValueError( + f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response" + ) return SentencePieceEncoder( self._model, @@ -228,18 +230,8 @@ def create_encoder( ) -@final -class BasicSentencePieceTokenizerLoader( - AbstractTextTokenizerLoader[BasicSentencePieceTokenizer] -): - """Loads tokenizers of type :class:`BasicSentencePieceTokenizer`.""" - - @override - def _load(self, path: Path, card: AssetCard) -> BasicSentencePieceTokenizer: - return BasicSentencePieceTokenizer(path) - - -default_basic_sentencepiece_tokenizer_loader = BasicSentencePieceTokenizerLoader() +def load_basic_sentencepiece(path: Path, card: AssetCard) -> SentencePieceTokenizer: + return BasicSentencePieceTokenizer(path) @final @@ -257,13 +249,13 @@ def __init__(self, path: Path) -> None: def create_encoder( self, *, - task: Optional[str] = None, - lang: Optional[str] = None, - mode: Optional[str] = None, - device: Optional[Device] = None, + task: str | None = None, + lang: str | None = None, + mode: str | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> SentencePieceEncoder: - """Create a token encoder. + """Constructs a token encoder. :param task: Must be ``None``. @@ -288,18 +280,8 @@ def create_encoder( return self.create_raw_encoder(device=device, pin_memory=pin_memory) -@final -class RawSentencePieceTokenizerLoader( - AbstractTextTokenizerLoader[RawSentencePieceTokenizer] -): - """Loads tokenizers of type :class:`RawSentencePieceTokenizer`.""" - - @override - def _load(self, path: Path, card: AssetCard) -> RawSentencePieceTokenizer: - return RawSentencePieceTokenizer(path) - - -default_raw_sentencepiece_tokenizer_loader = RawSentencePieceTokenizerLoader() +def load_raw_sentencepiece(path: Path, card: AssetCard) -> SentencePieceTokenizer: + return RawSentencePieceTokenizer(path) def vocab_info_from_sentencepiece(model: SentencePieceModel) -> VocabularyInfo: diff --git a/src/fairseq2/data/text/tokenizers/static.py b/src/fairseq2/data/text/tokenizers/static.py new file mode 100644 index 000000000..d157993fe --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/static.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.assets import AssetCard +from fairseq2.context import get_runtime_context +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerHandler, + TextTokenizerNotFoundError, + get_text_tokenizer_family, +) +from fairseq2.data.text.tokenizers.ref import resolve_text_tokenizer_reference +from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer + + +def load_text_tokenizer( + name_or_card: str | AssetCard, *, force: bool = False +) -> TextTokenizer: + context = get_runtime_context() + + if isinstance(name_or_card, AssetCard): + card = name_or_card + else: + card = context.asset_store.retrieve_card(name_or_card) + + card = resolve_text_tokenizer_reference(context.asset_store, card) + + family = get_text_tokenizer_family(card) + + registry = context.get_registry(TextTokenizerHandler) + + try: + handler = registry.get(family) + except LookupError: + raise TextTokenizerNotFoundError(card.name) from None + + return handler.load(card, force=force) diff --git a/src/fairseq2/data/text/tiktoken.py b/src/fairseq2/data/text/tokenizers/tiktoken.py similarity index 75% rename from src/fairseq2/data/text/tiktoken.py rename to src/fairseq2/data/text/tokenizers/tiktoken.py index 47c866582..cf83b80d4 100644 --- a/src/fairseq2/data/text/tiktoken.py +++ b/src/fairseq2/data/text/tokenizers/tiktoken.py @@ -6,21 +6,23 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path -from typing import List, Optional, Sequence, final +from typing import final import torch from tiktoken import Encoding from tiktoken.load import load_tiktoken_bpe from torch import Tensor +from typing_extensions import override -from fairseq2.data.text.text_tokenizer import ( +from fairseq2.data.text.tokenizers.tokenizer import ( AbstractTextTokenizer, TextTokenDecoder, TextTokenEncoder, ) from fairseq2.data.vocabulary_info import VocabularyInfo -from fairseq2.typing import Device, override +from fairseq2.typing import Device class TiktokenTokenizer(AbstractTextTokenizer): @@ -34,27 +36,26 @@ def __init__( path: Path, split_regex: str, *, - unk_token: Optional[str] = None, - bos_token: Optional[str] = None, - eos_token: Optional[str] = None, - pad_token: Optional[str] = None, - special_tokens: Optional[Sequence[str]] = None, + unk_token: str | None = None, + bos_token: str | None = None, + eos_token: str | None = None, + pad_token: str | None = None, + boh_token: str | None = None, + eoh_token: str | None = None, + special_tokens: Sequence[str] | None = None, ) -> None: """ - :param path: - The path to the tiktoken BPE file. - :param split_regex: - The regex pattern string that is used to split the input text. - :param unk_token: - The token that represents an unknown element (UNK). - :param bos_token: - The token that represents the beginning of a sequence (BOS). - :param eos_token: - The token that represents the end of of a sequence (EOS). - :param pad_token: - The token that is used to pad a sequence (PAD). - :param special_tokens: - The extra special tokens to include in the tokenizer. + :param path: The path to the tiktoken BPE file. + :param split_regex: The regex pattern string that is used to split the + input text. + :param unk_token: The token that represents an unknown element. + :param bos_token: The token that represents the beginning of a sequence. + :param eos_token: The token that represents the end of a sequence. + :param pad_token: The token that is used to pad a sequence. + :param boh_token: The token that represents the beginning of a header. + :param eoh_token: The token that represents the end of a header. + :param special_tokens: The extra special tokens to include in the + tokenizer. """ tokens = load_tiktoken_bpe(str(path)) @@ -76,7 +77,7 @@ def __init__( special_tokens=special_token_map, ) - def maybe_index(token: Optional[str]) -> Optional[int]: + def maybe_index(token: str | None) -> int | None: if token: return self._encoding.encode_single_token(token) @@ -88,13 +89,15 @@ def maybe_index(token: Optional[str]) -> Optional[int]: bos_idx=maybe_index(bos_token), eos_idx=maybe_index(eos_token), pad_idx=maybe_index(pad_token), + boh_idx=maybe_index(boh_token), + eoh_idx=maybe_index(eoh_token), ) super().__init__(vocab_info) @override def create_raw_encoder( - self, *, device: Optional[Device] = None, pin_memory: bool = False + self, *, device: Device | None = None, pin_memory: bool = False ) -> TiktokenEncoder: return TiktokenEncoder(self._encoding, device=device, pin_memory=pin_memory) @@ -102,31 +105,25 @@ def create_raw_encoder( def create_decoder(self) -> TiktokenDecoder: return TiktokenDecoder(self._encoding, self._num_bpe_tokens) - @final - @property - def encoding(self) -> Encoding: - """The tiktoken :class:`Encoding` object.""" - return self._encoding - @final class TiktokenEncoder(TextTokenEncoder): """Represents a tiktoken decoder.""" - _prefix_indices: List[int] - _suffix_indices: List[int] - _prefix_index_tensor: Optional[Tensor] - _suffix_index_tensor: Optional[Tensor] - _device: Optional[Device] + _prefix_indices: list[int] + _suffix_indices: list[int] + _prefix_index_tensor: Tensor | None + _suffix_index_tensor: Tensor | None + _device: Device | None _pin_memory: bool def __init__( self, encoding: Encoding, *, - prefix_tokens: Optional[Sequence[str]] = None, - suffix_tokens: Optional[Sequence[str]] = None, - device: Optional[Device] = None, + prefix_tokens: Sequence[str] | None = None, + suffix_tokens: Sequence[str] | None = None, + device: Device | None = None, pin_memory: bool = False, ) -> None: """ @@ -189,7 +186,7 @@ def __call__(self, text: str) -> Tensor: ) @override - def encode_as_tokens(self, text: str) -> List[str]: + def encode_as_tokens(self, text: str) -> list[str]: indices = self(text).tolist() b = self._encoding.decode_tokens_bytes(indices) @@ -198,12 +195,12 @@ def encode_as_tokens(self, text: str) -> List[str]: @property @override - def prefix_indices(self) -> Optional[Tensor]: + def prefix_indices(self) -> Tensor | None: return self._prefix_index_tensor @property @override - def suffix_indices(self) -> Optional[Tensor]: + def suffix_indices(self) -> Tensor | None: return self._suffix_index_tensor diff --git a/src/fairseq2/data/text/tokenizers/tokenizer.py b/src/fairseq2/data/text/tokenizers/tokenizer.py new file mode 100644 index 000000000..bd482281f --- /dev/null +++ b/src/fairseq2/data/text/tokenizers/tokenizer.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import final + +from torch import Tensor +from typing_extensions import override + +from fairseq2.data.vocabulary_info import VocabularyInfo +from fairseq2.typing import Device + + +class TextTokenizer(ABC): + """Represents a tokenizer to encode and decode text.""" + + @abstractmethod + def create_encoder( + self, + *, + task: str | None = None, + lang: str | None = None, + mode: str | None = None, + device: Device | None = None, + pin_memory: bool = False, + ) -> TextTokenEncoder: + """Constructs a token encoder. + + The valid arguments for the ``task``, ``lang``, and ``mode`` parameters + are implementation specific. Refer to concrete ``TextTokenizer`` + subclasses for more information. + + :param task: + The task for which to generate token indices. Typically, ``task`` is + used to distinguish between different tasks such as 'translation' or + 'transcription'. + :param lang: + The language of generated token indices. Typically, multilingual + translation tasks use ``lang`` to distinguish between different + languages such as 'en-US' or 'de-DE'. + :param mode: + The mode in which to generate token indices. Typically, translation + tasks use ``mode`` to distinguish between different modes such as + 'source' or 'target'. + :param device: + The device on which to construct tensors. + :param pin_memory: + If ``True``, uses pinned memory while constructing tensors. + """ + + @abstractmethod + def create_raw_encoder( + self, *, device: Device | None = None, pin_memory: bool = False + ) -> TextTokenEncoder: + """Constructs a raw token encoder with no control symbols. + + :param device: + The device on which to construct tensors. + :param pin_memory: + If ``True``, uses pinned memory while constructing tensors. + """ + + @abstractmethod + def create_decoder(self) -> TextTokenDecoder: + """Constructs a token decoder.""" + + @property + @abstractmethod + def vocab_info(self) -> VocabularyInfo: + """The vocabulary information associated with the tokenizer.""" + + +class AbstractTextTokenizer(TextTokenizer): + """Provides a skeletal implementation of :class:`TextTokenizer`.""" + + _vocab_info: VocabularyInfo + + def __init__(self, vocab_info: VocabularyInfo) -> None: + """ + :param vocab_info: + The vocabulary information associated with the tokenizer. + """ + self._vocab_info = vocab_info + + @final + @property + @override + def vocab_info(self) -> VocabularyInfo: + """The vocabulary information associated with the tokenizer.""" + return self._vocab_info + + +class TextTokenEncoder(ABC): + """Encodes text into tokens or token indices.""" + + @abstractmethod + def __call__(self, text: str) -> Tensor: + """ + :param text: + The text to encode. + """ + + @abstractmethod + def encode_as_tokens(self, text: str) -> list[str]: + """ + :param text: + The text to encode. + """ + + @property + @abstractmethod + def prefix_indices(self) -> Tensor | None: + """Get the indices of the prefix tokens. *Shape:* :math:`(S)`, where + :math:`S` is the number of indices.""" + + @property + @abstractmethod + def suffix_indices(self) -> Tensor | None: + """Get the indices of the suffix tokens. *Shape:* :math:`(S)`, where + :math:`S` is the number of indices.""" + + +class TextTokenDecoder(ABC): + """Decodes text from tokens or token indices.""" + + @abstractmethod + def __call__(self, token_indices: Tensor) -> str: + """ + :param token_indices: + The token indices to decode from. + """ + + @abstractmethod + def decode_from_tokens(self, tokens: Sequence[str]) -> str: + """ + :param tokens: + The tokens to decode from. + """ diff --git a/src/fairseq2/data/utils.py b/src/fairseq2/data/utils.py new file mode 100644 index 000000000..7783af577 --- /dev/null +++ b/src/fairseq2/data/utils.py @@ -0,0 +1,71 @@ +from collections.abc import Callable, Iterator +from typing import TypeVar + +from typing_extensions import Self, TypeAlias + +from fairseq2.data import DataPipelineBuilder, read_iterator + +T = TypeVar("T") + +IteratorFactory: TypeAlias = Callable[[], Iterator[T]] + + +class IteratorPickleWrapper(Iterator[T]): + def __init__(self, iterator_factory: IteratorFactory[T]) -> None: + self._iterator_factory: IteratorFactory[T] = iterator_factory + self._iterator: Iterator[T] = self._iterator_factory() + self._counter = 0 + + def __iter__(self) -> Self: + return self + + def __next__(self) -> T: + out = next(self._iterator) + self._counter += 1 + return out + + def __getstate__(self) -> tuple[IteratorFactory[T], int]: + return self._iterator_factory, self._counter + + def __setstate__(self, state: tuple[IteratorFactory[T], int]) -> None: + self._iterator_factory, counter = state + self._iterator = self._iterator_factory() + for i in range(counter): + next(self._iterator) + self._counter = counter + + +def read_pickle_wrapped_iterator( + iterator_factory: IteratorFactory[T], +) -> DataPipelineBuilder: + """Read each element of iterator generated by ``iterator_factory``. + + If ``iterator_factory`` is not pickleable, then this function wraps the + iterator in ``IteratorPickleWrapper``, a simple class that increments + an internal ``_counter`` every time ``__next__(self)`` is called. + Upon pickling, this counter is saved, and upon unpickling, a new iterator + is generated from ``iterator_factory`` and ``__next__(self)`` is called + ``counter`` many times. Note that this means the time complexity of + unpickling is linear in the number of times ``__next__(self)`` was called + prior to pickling. + + :param iterator_factory: + The iterator factory. + """ + + iterator = iterator_factory() + try: + return read_iterator( + iterator, reset_fn=lambda x: iterator_factory(), infinite=False + ) + except TypeError as e: + if ( + str(e) + != "`iterator` is not pickleable; set `skip_pickling_check` to True to bypass (see `read_iterator` documentation for details)." + ): + raise + return read_iterator( + IteratorPickleWrapper(iterator_factory), + reset_fn=lambda x: IteratorPickleWrapper(iterator_factory), + infinite=False, + ) diff --git a/src/fairseq2/data/vocabulary_info.py b/src/fairseq2/data/vocabulary_info.py index 9334900eb..df68d6aa4 100644 --- a/src/fairseq2/data/vocabulary_info.py +++ b/src/fairseq2/data/vocabulary_info.py @@ -7,10 +7,8 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, final -@final @dataclass class VocabularyInfo: """Describes the vocabulary used by a tokenizer""" @@ -18,14 +16,20 @@ class VocabularyInfo: size: int """The size of the vocabulary.""" - unk_idx: Optional[int] + unk_idx: int | None """The index of the symbol that represents an unknown element (UNK).""" - bos_idx: Optional[int] + bos_idx: int | None """The index of the symbol that represents the beginning of a sequence (BOS).""" - eos_idx: Optional[int] + eos_idx: int | None """The index of the symbol that represents the end of a sequence (EOS).""" - pad_idx: Optional[int] + pad_idx: int | None """The index of the symbol that is used to pad a sequence (PAD).""" + + boh_idx: int | None = None + """The index of the symbol that represents the beginning of a header (BOH).""" + + eoh_idx: int | None = None + """The index of the symbol that represents the end of a header (EOH).""" diff --git a/src/fairseq2/datasets/__init__.py b/src/fairseq2/datasets/__init__.py index 709f59ba7..386794bab 100644 --- a/src/fairseq2/datasets/__init__.py +++ b/src/fairseq2/datasets/__init__.py @@ -6,19 +6,18 @@ from __future__ import annotations -from fairseq2.datasets.batching import LengthBatching as LengthBatching -from fairseq2.datasets.batching import StaticBatching as StaticBatching +from fairseq2.datasets.config import Batching as Batching +from fairseq2.datasets.config import DataReadOptions as DataReadOptions +from fairseq2.datasets.config import LengthBatching as LengthBatching +from fairseq2.datasets.config import StaticBatching as StaticBatching from fairseq2.datasets.data_reader import DataPipelineReader as DataPipelineReader from fairseq2.datasets.data_reader import DataReader as DataReader +from fairseq2.datasets.data_reader import SyncMode as SyncMode +from fairseq2.datasets.error import DataReadError as DataReadError from fairseq2.datasets.error import DatasetError as DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader as AbstractDatasetLoader -from fairseq2.datasets.loader import DatasetLoader as DatasetLoader -from fairseq2.datasets.loader import DelegatingDatasetLoader as DelegatingDatasetLoader -from fairseq2.datasets.loader import get_dataset_family as get_dataset_family -from fairseq2.datasets.loader import is_dataset_card as is_dataset_card - -# isort: split - -import fairseq2.datasets.asr -import fairseq2.datasets.instruction -import fairseq2.datasets.parallel_text +from fairseq2.datasets.handler import DatasetHandler as DatasetHandler +from fairseq2.datasets.handler import DatasetLoader as DatasetLoader +from fairseq2.datasets.handler import DatasetNotFoundError as DatasetNotFoundError +from fairseq2.datasets.handler import StandardDatasetHandler as StandardDatasetHandler +from fairseq2.datasets.handler import get_dataset_family as get_dataset_family +from fairseq2.datasets.static import load_dataset as load_dataset diff --git a/src/fairseq2/datasets/asr.py b/src/fairseq2/datasets/asr.py index 8659964a0..1bfc31a46 100644 --- a/src/fairseq2/datasets/asr.py +++ b/src/fairseq2/datasets/asr.py @@ -7,14 +7,16 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional, Set, Union, cast, final +from typing import Any, Final, cast, final import torch from torch import Tensor from torch.nn.functional import layer_norm +from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -26,21 +28,21 @@ read_sequence, ) from fairseq2.data.audio import AudioDecoder -from fairseq2.data.text import ( - StrSplitter, - TextTokenizer, - default_raw_sentencepiece_tokenizer_loader, - load_text_tokenizer, - read_text, +from fairseq2.data.text import StrSplitter, TextTokenizer, read_text +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, ) -from fairseq2.datasets.batching import LengthBatching, StaticBatching from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.error import DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.error import DatasetError, SplitNotFoundError +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.typing import DataType, override +from fairseq2.typing import DataType class AsrDataset(ABC): @@ -52,21 +54,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_audio_len: int, max_audio_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - dtype: DataType = torch.float32, - min_audio_len: int = 1, - normalize_audio: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: AsrReadOptions | None = None, ) -> DataReader[Seq2SeqBatch]: """Create a dataset reader. @@ -76,94 +67,77 @@ def create_reader( The tokenizer to encode target text. :param gang: The gang over which to shard the dataset. + :param min_audio_len: + The minimum audio length of each example. Examples shorter than this + value will be dropped. :param max_audio_len: - The maximum audio length of each example. Examples longer than - this value will be dropped. + The maximum audio length of each example. Examples longer than this + value will be dropped. :param batching: The batching strategy for returned examples. - :param dtype: - The data type of the decoded audio sequences. - :param min_audio_len: - The minimum audio length of each example. Examples shorter than - this value will be dropped. - :param normalize_audio: - If ``True``, normalizes audio to have zero mean and unit variance. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. + :param options: + The read options. """ @abstractmethod - def splits(self) -> Set[str]: + def splits(self) -> set[str]: """Return the set of splits.""" -load_asr_dataset = DelegatingDatasetLoader[AsrDataset]() +@dataclass +class AsrReadOptions(DataReadOptions): + dtype: DataType = torch.float32 + """The data type of the decoded audio sequences.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" # TODO: FIX, INFER npc = 10 +GENERIC_ASR_DATASET_FAMILY: Final = "generic_asr" + + # TODO: Work in progress! @final class GenericAsrDataset(AsrDataset): """Represents a generic manifest-based ASR dataset.""" + _name: str _manifest_dir: Path - _splits: Set[str] + _splits: set[str] - def __init__(self, manifest_dir: Path, splits: Set[str]) -> None: + def __init__(self, name: str, manifest_dir: Path, splits: set[str]) -> None: """ :param manifest_dir: The directory under which the manifest files resides. :param splits: The available splits. """ + self._name = name self._manifest_dir = manifest_dir self._splits = splits - @classmethod - def from_path(cls, path: Path) -> GenericAsrDataset: - """Load a :class:`GenericAsrDataset` from ``path``.""" + @staticmethod + def from_path(path: Path, name: str | None = None) -> GenericAsrDataset: + if name is None: + name = f"path:{path.name}" + path = path.expanduser().resolve() if not path.is_dir(): - return GenericAsrDataset(manifest_dir=path.parent, splits={path.stem}) + return GenericAsrDataset(name, manifest_dir=path.parent, splits={path.stem}) try: splits = {f.stem for f in path.glob("*.tsv")} except OSError as ex: - raise RuntimeError( - "The splits cannot be determined. See nested exception for details." + raise DatasetError( + f"The splits under the '{path}' directory cannot be determined. See the nested exception for details." ) from ex - return GenericAsrDataset(path, splits) + return GenericAsrDataset(name, path, splits) @override def create_reader( @@ -171,22 +145,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_audio_len: int, max_audio_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - dtype: DataType = torch.float32, - min_audio_len: int = 1, - normalize_audio: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - cached_fd_count: int = 1000, - **extras: Any, + batching: Batching, + options: AsrReadOptions | None = None, ) -> DataPipelineReader[Seq2SeqBatch]: """ :param cached_fd_count: @@ -194,17 +156,20 @@ def create_reader( audio files. """ if split not in self._splits: - raise ValueError( - f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits))}" - ) + raise SplitNotFoundError(self._name, split, self._splits) + + if options is None: + options = AsrReadOptions() + + seed = options.seed audio_dir = self._retrieve_data_directory(split) builder = self._read_manifest(split) # Shuffle examples. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(example_shuffle_window, seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed) seed += 1 @@ -216,8 +181,8 @@ def create_reader( if isinstance(batching, LengthBatching): # Bucket by the audio length. bucket_sizes = create_bucket_sizes( - max_seq_len=max_audio_len, min_seq_len=min_audio_len, + max_seq_len=max_audio_len, max_num_elements=batching.max_num_elements, num_seqs_multiple_of=8, ) @@ -228,11 +193,11 @@ def create_reader( min_data_len=min_audio_len, skip_below_min_examples=True, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) - else: + elif isinstance(batching, StaticBatching): # Filter out out-of-range audios. - def skip(example: Dict[str, Any]) -> bool: + def skip(example: dict[str, object]) -> bool: audio_len = cast(int, example["audio_size"]) return audio_len >= min_audio_len and audio_len <= max_audio_len @@ -240,21 +205,31 @@ def skip(example: Dict[str, Any]) -> bool: builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed) seed += 1 # Memory map audio files. + cached_fd_count = options.extras.get("cached_fd_count", 1) + if not isinstance(cached_fd_count, int): + raise TypeError( + f"`options.extras['cached_fd_count']` must be of type `int`, but is of type `{type(cached_fd_count)}` instead." + ) + file_mapper = FileMapper(audio_dir, cached_fd_count=cached_fd_count) builder.map(file_mapper, selector="[*].audio") # Decode audio. - audio_decoder = AudioDecoder(dtype=torch.float32 if normalize_audio else dtype) + audio_decoder = AudioDecoder( + dtype=torch.float32 if options.normalize_audio else options.dtype + ) builder.map(audio_decoder, selector="[*].audio.data") @@ -265,9 +240,9 @@ def normalize(waveform: Tensor) -> Tensor: with torch.no_grad(): waveform = layer_norm(waveform, waveform.shape) - return waveform.to(dtype) + return waveform.to(options.dtype) - if normalize_audio: + if options.normalize_audio: builder.map(normalize, selector="[*].audio.data.waveform") # Tokenize target text. @@ -285,14 +260,14 @@ def normalize(waveform: Tensor) -> Tensor: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `Seq2SeqBatch`. - def to_batch(example: Dict[str, Any]) -> Seq2SeqBatch: + def to_batch(example: dict[str, Any]) -> Seq2SeqBatch: source_data = cast(SequenceData, example["audio"]["data"]["waveform"]) target_data = cast(SequenceData, example["text"]) @@ -314,11 +289,13 @@ def to_batch(example: Dict[str, Any]) -> Seq2SeqBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[Seq2SeqBatch]( + self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) def _retrieve_data_directory(self, split: str) -> Path: @@ -329,14 +306,15 @@ def _retrieve_data_directory(self, split: str) -> Path: line = fp.readline().rstrip() except OSError as ex: raise DatasetError( - f"{manifest_file} cannot be read. See nested exception for details." + self._name, + f"The {manifest_file} manifest file cannot be read. See the nested exception for details.", ) from ex try: return Path(line) except ValueError: raise DatasetError( - f"The first line of {manifest_file} must point to a data directory." + f"The first line of the '{manifest_file}' manifest file must point to a data directory." ) from None def _read_manifest(self, split: str) -> DataPipelineBuilder: @@ -372,26 +350,11 @@ def read_wrd_file() -> DataPipelineBuilder: return read_sequence(manifest) @override - def splits(self) -> Set[str]: + def splits(self) -> set[str]: return self._splits -@final -class GenericAsrDatasetLoader(AbstractDatasetLoader[GenericAsrDataset]): - @override - def _load(self, path: Path, card: AssetCard) -> GenericAsrDataset: - try: - return GenericAsrDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_asr_dataset = GenericAsrDatasetLoader() - -load_asr_dataset.register("generic_asr", load_generic_asr_dataset) - -load_librispeech_asr_tokenizer = default_raw_sentencepiece_tokenizer_loader - -load_text_tokenizer.register("librispeech_asr", load_librispeech_asr_tokenizer) +def load_asr_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> AsrDataset: + return load_dataset(name_or_card, AsrDataset, force=force) diff --git a/src/fairseq2/datasets/batching.py b/src/fairseq2/datasets/batching.py deleted file mode 100644 index e5bbd8493..000000000 --- a/src/fairseq2/datasets/batching.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass -class StaticBatching: - """Specifies batching where each batch has the same number of examples.""" - - batch_size: int - """The number of examples in each batch.""" - - -@dataclass -class LengthBatching: - """Specifies batching where each batch has a maximum number of elements.""" - - max_num_elements: int - """The maximum number of elements (e.g. tokens) in each batch.""" diff --git a/src/fairseq2/datasets/config.py b/src/fairseq2/datasets/config.py new file mode 100644 index 000000000..8eda42644 --- /dev/null +++ b/src/fairseq2/datasets/config.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import MutableMapping +from dataclasses import dataclass, field +from typing import TypeAlias + +from fairseq2.datasets.data_reader import SyncMode + + +@dataclass +class StaticBatching: + """Specifies batching where each batch has the same number of examples.""" + + batch_size: int + """The number of examples in each batch.""" + + +@dataclass +class LengthBatching: + """Specifies batching where each batch has a maximum number of elements.""" + + max_num_elements: int + """The maximum number of elements (e.g. tokens) in each batch.""" + + +Batching: TypeAlias = StaticBatching | LengthBatching + + +@dataclass(kw_only=True) +class DataReadOptions: + example_shuffle_window: int = 1 + """ + The size of the sliding window for shuffling examples. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + batch_shuffle_window: int = 1 + """ + The size of the sliding window for shuffling batches. If ``1``, no + shuffling is performed; if ``0``, true shuffling is performed by loading the + entire dataset. + """ + + drop_remainder: bool = False + """ + If ``True``, drops the last set of batches if they have in total fewer + examples than requested. + """ + + sync_batches: bool = True + """ + If ``True``, ensures that each process in ``gang`` reads the same number of + batches. Typically used when the amount of data to be read can vary per + process (e.g. due to unbalanced sharding or non-static batching) and it is + critical for each process to iterate over the same number of batches (e.g. + during training). + """ + + sync_mode: SyncMode = "until_first" + """ + If ``until_first``, stops iteration on all ranks when one of the ranks + reaches its end of data. If ``until_last``, stops iteration when all ranks + reach their end of data; ranks that have already reached their end of data + will return an empty list of batches. + """ + + max_num_batches: int | None = None + """The maximum number of batches to return.""" + + num_accumulate: int = 1 + """ + The number of batches to accumulate in each iteration. Typically used with + gradient accumulation during training. + """ + + num_prefetch: int = 1 + """The number of batches to prefetch in background.""" + + seed: int = 2 + """The seed to initialize the random number generators used internally.""" + + extras: MutableMapping[str, object] = field(default_factory=dict) + """The reader-specific extra options.""" diff --git a/src/fairseq2/datasets/data_reader.py b/src/fairseq2/datasets/data_reader.py index d975c5bca..478cb35ea 100644 --- a/src/fairseq2/datasets/data_reader.py +++ b/src/fairseq2/datasets/data_reader.py @@ -7,25 +7,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Iterator, List, Mapping, TypeVar, final +from collections.abc import Iterator, Mapping +from typing import Literal, TypeAlias, TypeVar, final -from typing_extensions import Self +from typing_extensions import Self, override -from fairseq2.data import DataPipeline -from fairseq2.datasets.utils import _reduce_num_batches -from fairseq2.gang import Gang -from fairseq2.logging import get_log_writer -from fairseq2.typing import override - -log = get_log_writer(__name__) - - -BatchT = TypeVar("BatchT") +from fairseq2.data import DataPipeline, DataPipelineError +from fairseq2.datasets.error import DataReadError +from fairseq2.datasets.utils import _min_num_batches, _sum_num_batches +from fairseq2.gang import Gang, GangError BatchT_co = TypeVar("BatchT_co", covariant=True) -class DataReader(ABC, Iterator[List[BatchT_co]]): +class DataReader(ABC, Iterator[list[BatchT_co]]): """Reads batches of examples from a dataset.""" @abstractmethod @@ -33,7 +28,7 @@ def __iter__(self) -> Self: ... @abstractmethod - def __next__(self) -> List[BatchT_co]: + def __next__(self) -> list[BatchT_co]: ... @abstractmethod @@ -41,11 +36,11 @@ def reset(self) -> None: """Reset state and move back to the first batch.""" @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, object]: ... @abstractmethod - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: ... @property @@ -54,10 +49,17 @@ def num_accumulate(self) -> int: """The number of batches accumulated in each iteration.""" +SyncMode: TypeAlias = Literal["until_first", "until_last"] + + +BatchT = TypeVar("BatchT") + + @final class DataPipelineReader(DataReader[BatchT]): """Reads batches of examples from a dataset using a :class:`DataPipeline`.""" + _name: str _pipeline: DataPipeline _pipeline_iter: Iterator[BatchT] _gang: Gang @@ -67,14 +69,17 @@ class DataPipelineReader(DataReader[BatchT]): def __init__( self, + name: str, pipeline: DataPipeline, gang: Gang, *, num_accumulate: int = 1, drop_remainder: bool = True, - sync_batches: bool = False, + sync_batches: bool = True, + sync_mode: SyncMode = "until_first", ) -> None: """ + :param name: The name of the dataset. :param pipeline: The data pipeline to iterate over. :param gang: @@ -90,13 +95,20 @@ def __init__( across all processes in the gang. Typically used when the amount of data to be read can vary per process (e.g. due to bucketing) and it is critical for each process to iterate over same number of batches. + :param sync_mode: + If ``until_first``, stops iteration when the first rank reaches end + of data. If ``until_last``, stops iteration when the last rank + reaches end of data; ranks that have already reached their end of + data will return an empty list of batches. """ + self._name = name self._pipeline = pipeline self._pipeline_iter = iter(pipeline) self._gang = gang self._num_accumulate = num_accumulate self._drop_remainder = drop_remainder self._sync_batches = sync_batches + self._sync_until_last = sync_mode == "until_last" self._eod = False @override @@ -104,7 +116,7 @@ def __iter__(self) -> Self: return self @override - def __next__(self) -> List[BatchT]: + def __next__(self) -> list[BatchT]: if self._eod: raise StopIteration() @@ -115,20 +127,37 @@ def __next__(self) -> List[BatchT]: batch = next(self._pipeline_iter) except StopIteration: break + except DataPipelineError as ex: + raise DataReadError( + self._name, "The data pipeline has failed to read the next batch. See the nested exception for details." # fmt: skip + ) from ex batches.append(batch) - if self._sync_batches and self._gang.size > 1: - num_batches = _reduce_num_batches(len(batches), self._gang, log) - - batches = batches[:num_batches] - # If we read less than `num_accumulate` batches, it means we reached end # of data. if self._drop_remainder and len(batches) != self._num_accumulate: batches.clear() - self._eod = len(batches) == 0 + local_num_batches = len(batches) + + if self._sync_batches and self._gang.size > 1: + try: + if self._sync_until_last: + num_batches = _sum_num_batches(local_num_batches, self._gang) + else: + num_batches = _min_num_batches(local_num_batches, self._gang) + + if num_batches != local_num_batches: + batches = batches[:num_batches] + except GangError as ex: + raise DataReadError( + self._name, "The batch synchronization of the gang processes has failed. See the nested exception for details." # fmt: skip + ) from ex + else: + num_batches = local_num_batches + + self._eod = num_batches == 0 if self._eod: raise StopIteration() @@ -142,11 +171,11 @@ def reset(self) -> None: self._pipeline.reset() @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, object]: return self._pipeline.state_dict() @override - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: self._eod = False self._pipeline.load_state_dict(state_dict) diff --git a/src/fairseq2/datasets/error.py b/src/fairseq2/datasets/error.py index 5a0b94f1d..71e312ec3 100644 --- a/src/fairseq2/datasets/error.py +++ b/src/fairseq2/datasets/error.py @@ -6,6 +6,29 @@ from __future__ import annotations +from collections.abc import Set -class DatasetError(RuntimeError): - """Raised when a dataset can't be read.""" + +class DatasetError(Exception): + pass + + +class DataReadError(Exception): + pass + + +class SplitNotFoundError(LookupError): + name: str + split: str + available_splits: Set[str] + + def __init__(self, name: str, split: str, available_splits: Set[str]) -> None: + s = ", ".join(sorted(available_splits)) + + super().__init__( + f"`split` must be one of the following splits, but is '{split}' instead: {s}" + ) + + self.name = name + self.split = split + self.available_splits = available_splits diff --git a/src/fairseq2/datasets/handler.py b/src/fairseq2/datasets/handler.py new file mode 100644 index 000000000..3ec2f7334 --- /dev/null +++ b/src/fairseq2/datasets/handler.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Protocol, final + +from typing_extensions import override + +from fairseq2.assets import AssetCard, AssetDownloadManager, AssetError +from fairseq2.datasets.error import DatasetError + + +class DatasetHandler(ABC): + @abstractmethod + def load(self, card: AssetCard, *, force: bool) -> object: + ... + + @abstractmethod + def load_from_path(self, path: Path) -> object: + ... + + @property + @abstractmethod + def kls(self) -> type: + ... + + +class DatasetNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known dataset.") + + self.name = name + + +class DatasetLoader(Protocol): + def __call__(self, path: Path, name: str | None) -> object: + ... + + +@final +class StandardDatasetHandler(DatasetHandler): + _kls: type + _loader: DatasetLoader + _asset_download_manager: AssetDownloadManager + + def __init__( + self, + kls: type, + loader: DatasetLoader, + asset_download_manager: AssetDownloadManager, + ) -> None: + self._kls = kls + self._loader = loader + self._asset_download_manager = asset_download_manager + + @override + def load(self, card: AssetCard, *, force: bool) -> object: + dataset_uri = card.field("data").as_uri() + + path = self._asset_download_manager.download_dataset( + dataset_uri, card.name, force=force + ) + + try: + return self._loader(path, card.name) + except DatasetError as ex: + raise AssetError( + f"The constructor of the '{card.name}' dataset has raised an error. See the nested exception for details." + ) from ex + + @override + def load_from_path(self, path: Path) -> object: + return self._loader(path, name=None) + + @override + @property + def kls(self) -> type: + return self._kls + + +def get_dataset_family(card: AssetCard) -> str: + return card.field("dataset_family").as_(str) # type: ignore[no-any-return] diff --git a/src/fairseq2/datasets/instruction.py b/src/fairseq2/datasets/instruction.py index 76146c915..db33a5120 100644 --- a/src/fairseq2/datasets/instruction.py +++ b/src/fairseq2/datasets/instruction.py @@ -8,13 +8,15 @@ import json from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Union, cast, final +from typing import Any, Final, cast, final import torch -from typing_extensions import NoReturn +from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -25,14 +27,20 @@ read_sequence, ) from fairseq2.data.text import TextTokenizer -from fairseq2.datasets.batching import LengthBatching, StaticBatching +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.error import DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.error import DatasetError, SplitNotFoundError +from fairseq2.datasets.static import load_dataset +from fairseq2.datasets.utils import _load_files_and_weights +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.sequence import SequenceBatch from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.typing import override class InstructionDataset(ABC): @@ -41,81 +49,49 @@ class InstructionDataset(ABC): @abstractmethod def create_reader( self, + split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: InstructionReadOptions | None = None, ) -> DataReader[SequenceBatch]: """Create a dataset reader. + :param split: + The split to read. :param tokenizer: The tokenizer to encode text. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param sample: - If ``True``, instruction sources (e.g. files) will be sampled in - proportion to their weights. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. + :param options: + The read options. """ @abstractmethod def create_prompt_reader( self, + split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: StaticBatching, - *, - drop_remainder: bool = False, - sync_batches: bool = True, - num_prefetch: int = 1, - **extras: Any, + options: InstructionPromptReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: """Create a dataset reader for evaluation. + :param split: + The split to read. :param tokenizer: The tokenizer to encode text. :param gang: @@ -125,173 +101,139 @@ def create_prompt_reader( this value will be dropped. :param batching: The batching strategy for returned examples. - :param drop_remainder: - If ``True``, drops the last batch if it has fewer examples than - requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param num_prefetch: - The number of batches to prefetch in background. - :param extras: - The extra parameters specific to the dataset implementation. + :param options: + The read options. """ + @abstractmethod + def splits(self) -> set[str]: + """Return the set of splits.""" + + +@dataclass +class InstructionReadOptions(DataReadOptions): + sample: bool = False + """ + If ``True``, instruction sources (e.g. JSONL files) will be sampled in + proportion to their weights. + """ + + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" + + target_encode_mode: str = "prompt_response" + """The tokenizer mode to encode the target text.""" -load_instruction_dataset = DelegatingDatasetLoader[InstructionDataset]() + +@dataclass +class InstructionPromptReadOptions(DataReadOptions): + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" # TODO: FIX, INFER npc = 10 +GENERIC_INSTRUCTION_DATASET_FAMILY: Final = "generic_instruction" + + # TODO: Work in progress! @final class GenericInstructionDataset(InstructionDataset): """Represents a generic JSONL instruction dataset.""" - _files: Sequence[Path] - _weights: Sequence[float] + _name: str + _splits: dict[str, tuple[Sequence[Path], Sequence[float]]] - def __init__(self, files: Sequence[Path], weights: Sequence[float]) -> None: + def __init__( + self, name: str, splits: dict[str, tuple[Sequence[Path], Sequence[float]]] + ) -> None: """ :param files: The instruction files. :param weights: The weight of each file in ``files``. """ - if len(files) != len(weights): - raise ValueError( - f"The lengths of `files` and `weights` must match, but they are {len(files)} and {len(weights)} instead." - ) + self._name = name - self._files = files - self._weights = weights + for split, (files, weights) in splits.items(): + if len(files) != len(weights): + raise ValueError( + f"The lengths of the file and weight lists of the '{split}' split must match, but they are {len(files)} and {len(weights)} instead." + ) - @classmethod - def from_path(cls, path: Path) -> GenericInstructionDataset: - """Load a :class:`InstructionDataset` from ``path``.""" - path = path.expanduser().resolve() + self._splits = splits - if not path.is_dir(): - return GenericInstructionDataset(files=[path], weights=[1.0]) + @staticmethod + def from_path(path: Path, name: str | None = None) -> GenericInstructionDataset: + if name is None: + name = f"path:{path.name}" - manifest_file = path.joinpath("MANIFEST") + splits: dict[str, tuple[Sequence[Path], Sequence[float]]] = {} - try: - with manifest_file.open() as fp: - content = list(fp) - except FileNotFoundError: - content = None - except OSError as ex: - raise RuntimeError( - f"{manifest_file} cannot be read. See nested exception for details." - ) from ex - - # If the directory does not contain a MANIFEST file, treat all JSONL - # files as part of the dataset with equal weight. - if content is None: + if path.is_dir(): try: - files = list(path.glob("**/*.jsonl")) + child_dirs = [p for p in path.iterdir() if p.is_dir()] except OSError as ex: - raise RuntimeError( - f"The JSONL files under {path} cannot be retrieved. See nested exception for details." - ) from ex - - weights = [1.0 for _ in range(len(files))] - - return GenericInstructionDataset(files, weights=weights) - - # Sort the JSONL files in alphabetical order. - content.sort() - - files = [] - - weights = [] - - # Each line of the MANIFEST file corresponds to the path of a JSONL file - # and its weight (e.g. number of examples). - for idx, line in enumerate(content): - - def raise_error() -> NoReturn: raise DatasetError( - f"Each line in {manifest_file} must represent a path to a JSONL file and a weight, but line {idx} is '{line}' instead." - ) from None - - fields = line.rstrip().split("\t") - - if len(fields) != 2: - raise_error() - - file_path = fields[0].strip() - if not file_path: - raise_error() - - try: - file = path.joinpath(file_path) - except ValueError: - raise_error() + name, f"The files under the '{path}' directory cannot be retrieved. See the nested exception for details." # fmt: skip + ) from ex - if not file.exists(): - raise DatasetError( - f"The file '{file}' referred at line {idx} in {manifest_file} does not exist." - ) + for child_dir in child_dirs: + files, weights = _load_files_and_weights(name, child_dir) - files.append(file) + splits[child_dir.name] = (files, weights) - try: - weight = float(fields[1].strip()) - except ValueError: - raise_error() + if not splits: + files, weights = _load_files_and_weights(name, path) - weights.append(weight) + splits["default"] = (files, weights) - return GenericInstructionDataset(files, weights) + return GenericInstructionDataset(name, splits) @override def create_reader( self, + split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: InstructionReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: - if len(self._files) == 1: - builder = self._read_jsonl(self._files[0], tokenizer) + files_weights = self._splits.get(split) + if files_weights is None: + raise SplitNotFoundError(self._name, split, self._splits.keys()) + + if options is None: + options = InstructionReadOptions() + + seed = options.seed + + files, weights = files_weights + + if len(files) == 1: + builder = self._read_jsonl(files[0], tokenizer) else: pipelines = [] - for file in self._files: + for file in files: pipeline = self._read_jsonl(file, tokenizer).and_return() pipelines.append(pipeline) - if sample: - builder = DataPipeline.sample( - pipelines, weights=self._weights, seed=seed - ) + if options.sample: + builder = DataPipeline.sample(pipelines, weights=weights, seed=seed) seed += 1 else: builder = DataPipeline.concat(pipelines) # Shuffle files. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(shuffle_window=0, seed=seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed=seed) seed += 1 @@ -300,22 +242,22 @@ def create_reader( seed += gang.rank - # Encode prompt and target texts. - prompt_encoder = tokenizer.create_encoder(mode="prompt") - target_encoder = tokenizer.create_encoder(mode="prompt_response") + # Encode source and target texts. + source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode) - builder.map(prompt_encoder, selector="src", num_parallel_calls=npc) + builder.map(source_encoder, selector="src", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt", num_parallel_calls=npc) - def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]: + def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: id_ = example.get("id") - prompt_indices = example["src"] + source_indices = example["src"] target_indices = example["tgt"] - indices = torch.cat([prompt_indices, target_indices]) + indices = torch.cat([source_indices, target_indices]) - target_mask = torch.arange(len(indices)) >= len(prompt_indices) + target_mask = torch.arange(len(indices)) >= len(source_indices) return {"id": id_, "indices": indices, "target_mask": target_mask} @@ -323,29 +265,36 @@ def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]: if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( - max_seq_len=max_seq_len, max_num_elements=batching.max_num_elements + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + max_num_elements=batching.max_num_elements, ) # Bucket by the sequence length. builder.bucket_by_length( bucket_sizes, selector="indices", + min_data_len=min_seq_len, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) - else: + elif isinstance(batching, StaticBatching): # Filter out long examples. - def skip(example: Dict[str, Any]) -> bool: - return len(example["indices"]) <= max_seq_len + def skip(example: dict[str, Any]) -> bool: + seq_len = len(example["indices"]) + + return seq_len >= min_seq_len and seq_len <= max_seq_len builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed=seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed=seed) seed += 1 @@ -359,14 +308,14 @@ def skip(example: Dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `SequenceBatch`. - def to_batch(example: Dict[str, Any]) -> SequenceBatch: + def to_batch(example: dict[str, Any]) -> SequenceBatch: indices = cast(SequenceData, example["indices"]) seqs, padding_mask = get_seqs_and_padding_mask(indices, gang.device) @@ -378,32 +327,40 @@ def to_batch(example: Dict[str, Any]) -> SequenceBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[SequenceBatch]( + self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) @override def create_prompt_reader( self, + split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, batching: StaticBatching, - *, - drop_remainder: bool = False, - sync_batches: bool = True, - num_prefetch: int = 1, - **extras: Any, + options: InstructionPromptReadOptions | None = None, ) -> DataPipelineReader[SequenceBatch]: - if len(self._files) == 1: - builder = self._read_jsonl(self._files[0], tokenizer) + try: + files, weights = self._splits[split] + except KeyError: + raise SplitNotFoundError(self._name, split, self._splits.keys()) from None + + if options is None: + options = InstructionPromptReadOptions() + + if len(files) == 1: + builder = self._read_jsonl(files[0], tokenizer) else: pipelines = [] - for file in self._files: + for file in files: pipeline = self._read_jsonl(file, tokenizer).and_return() pipelines.append(pipeline) @@ -413,28 +370,30 @@ def create_prompt_reader( # Shard builder.shard(gang.rank, gang.size, allow_uneven=True) - # Encode prompt texts. - text_encoder = tokenizer.create_encoder(mode="prompt") + # Encode source texts. + text_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) - def encode(example: Dict[str, Any]) -> Dict[str, Any]: + def encode(example: dict[str, Any]) -> dict[str, Any]: id_ = example.get("id") - prompt = example["src"] + source = example["src"] - indices = text_encoder(prompt) + indices = text_encoder(source) - return {"id": id_, "prompt": prompt, "indices": indices} + return {"id": id_, "prompt": source, "indices": indices} builder.map(encode, num_parallel_calls=npc) # Filter out long examples. - def skip(example: Dict[str, Any]) -> bool: - return len(example["indices"]) <= max_seq_len + def skip(example: dict[str, Any]) -> bool: + seq_len = len(example["indices"]) + + return seq_len >= min_seq_len and seq_len <= max_seq_len builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) # Collate bucketed examples into a batch. collater = Collater(pad_value=tokenizer.vocab_info.pad_idx or 0) @@ -442,10 +401,10 @@ def skip(example: Dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) # Wrap examples with `SequenceBatch`. - def to_batch(example: Dict[str, Any]) -> SequenceBatch: + def to_batch(example: dict[str, Any]) -> SequenceBatch: indices = cast(SequenceData, example["indices"]) seqs, padding_mask = get_seqs_and_padding_mask(indices, gang.device) @@ -455,7 +414,12 @@ def to_batch(example: Dict[str, Any]) -> SequenceBatch: pipeline = builder.map(to_batch).and_return() return DataPipelineReader[SequenceBatch]( - pipeline, gang, drop_remainder=drop_remainder, sync_batches=sync_batches + self._name, + pipeline, + gang, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder: @@ -468,21 +432,12 @@ def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuild return read_sequence(lines).map(json.loads, num_parallel_calls=npc) - -@final -class GenericInstructionDatasetLoader(AbstractDatasetLoader[GenericInstructionDataset]): @override - def _load(self, path: Path, card: AssetCard) -> GenericInstructionDataset: - try: - return GenericInstructionDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - + def splits(self) -> set[str]: + return set(self._splits.keys()) -load_generic_instruction_dataset = GenericInstructionDatasetLoader() -load_instruction_dataset.register( - "generic_instruction", load_generic_instruction_dataset -) +def load_instruction_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> InstructionDataset: + return load_dataset(name_or_card, InstructionDataset, force=force) diff --git a/src/fairseq2/datasets/loader.py b/src/fairseq2/datasets/loader.py deleted file mode 100644 index ee4ca6b1c..000000000 --- a/src/fairseq2/datasets/loader.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Dict, Optional, Protocol, TypeVar, Union, final - -from fairseq2.assets import ( - AssetCard, - AssetCardError, - AssetDownloadManager, - AssetError, - AssetStore, - default_asset_store, - default_download_manager, -) - -DatasetT = TypeVar("DatasetT") - -DatasetT_co = TypeVar("DatasetT_co", covariant=True) - - -class DatasetLoader(Protocol[DatasetT_co]): - """Loads datasets of type ``DatasetT```.""" - - def __call__( - self, - dataset_name_or_card: Union[str, AssetCard], - *, - force: bool = False, - progress: bool = True, - ) -> DatasetT_co: - """ - :param dataset_name_or_card: - The name or the asset card of the dataset to load. - :param force: - If ``True``, downloads the dataset even if it is already in cache. - :param progress: - If ``True``, displays a progress bar to stderr. - """ - - -class AbstractDatasetLoader(ABC, DatasetLoader[DatasetT]): - """Provides a skeletal implementation of :class:`DatasetLoader`.""" - - _asset_store: AssetStore - _download_manager: AssetDownloadManager - - def __init__( - self, - *, - asset_store: Optional[AssetStore] = None, - download_manager: Optional[AssetDownloadManager] = None, - ) -> None: - """ - :param asset_store: - The asset store where to check for available datasets. If ``None``, - the default asset store will be used. - :param download_manager: - The download manager. If ``None``, the default download manager will - be used. - """ - self._asset_store = asset_store or default_asset_store - self._download_manager = download_manager or default_download_manager - - @final - def __call__( - self, - dataset_name_or_card: Union[str, AssetCard], - *, - force: bool = False, - progress: bool = True, - ) -> DatasetT: - if isinstance(dataset_name_or_card, AssetCard): - card = dataset_name_or_card - else: - card = self._asset_store.retrieve_card(dataset_name_or_card) - - dataset_uri = card.field("data").as_uri() - - try: - path = self._download_manager.download_dataset( - dataset_uri, card.name, force=force, progress=progress - ) - except ValueError as ex: - raise AssetCardError( - f"The value of the field 'data' of the asset card '{card.name}' must be a URI. See nested exception for details." - ) from ex - - try: - return self._load(path, card) - except ValueError as ex: - raise AssetError( - f"The {card.name} dataset cannot be loaded. See nested exception for details." - ) from ex - - @abstractmethod - def _load(self, path: Path, card: AssetCard) -> DatasetT: - """ - :param path: - The path to the dataset. - :param card: - The asset card of the dataset. - """ - - -@final -class DelegatingDatasetLoader(DatasetLoader[DatasetT]): - """Loads datasets of type ``DatasetT`` using registered loaders.""" - - _asset_store: AssetStore - _loaders: Dict[str, DatasetLoader[DatasetT]] - - def __init__(self, *, asset_store: Optional[AssetStore] = None) -> None: - """ - :param asset_store: - The asset store where to check for available datasets. If ``None``, - the default asset store will be used. - """ - self._asset_store = asset_store or default_asset_store - - self._loaders = {} - - def __call__( - self, - dataset_name_or_card: Union[str, AssetCard], - *, - force: bool = False, - progress: bool = True, - ) -> DatasetT: - if isinstance(dataset_name_or_card, AssetCard): - card = dataset_name_or_card - else: - card = self._asset_store.retrieve_card(dataset_name_or_card) - - family = card.field("dataset_family").as_(str) - - try: - loader = self._loaders[family] - except KeyError: - raise AssetError( - f"The value of the field 'dataset_family' of the asset card '{card.name}' must be a supported dataset family, but '{family}' has no registered loader." - ) from None - - return loader(card, force=force, progress=progress) - - def register(self, family: str, loader: DatasetLoader[DatasetT]) -> None: - """Register a dataset loader to use with this loader. - - :param family: - The dataset type. If the 'dataset_family' field of an asset card - matches this value, the specified ``loader`` will be used. - :param loader: - The dataset loader. - """ - if family in self._loaders: - raise ValueError( - f"`family` must be a unique dataset family name, but '{family}' has already a registered loader." - ) - - self._loaders[family] = loader - - def supports(self, dataset_name_or_card: Union[str, AssetCard]) -> bool: - """Return ``True`` if the specified dataset has a registered loader.""" - if isinstance(dataset_name_or_card, AssetCard): - card = dataset_name_or_card - else: - card = self._asset_store.retrieve_card(dataset_name_or_card) - - family = card.field("dataset_family").as_(str) - - return family in self._loaders - - -def is_dataset_card(card: AssetCard) -> bool: - """Return ``True`` if ``card`` specifies a dataset.""" - return card.field("dataset_family").exists() - - -def get_dataset_family(card: AssetCard) -> str: - """Return the dataset family name contained in ``card``.""" - return card.field("dataset_family").as_(str) # type: ignore[no-any-return] diff --git a/src/fairseq2/datasets/parallel_text.py b/src/fairseq2/datasets/parallel_text.py index 4a1b9a10e..f53d9e42d 100644 --- a/src/fairseq2/datasets/parallel_text.py +++ b/src/fairseq2/datasets/parallel_text.py @@ -10,11 +10,11 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast, final +from typing import Any, Final, cast, final -from typing_extensions import NoReturn +from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( Collater, DataPipeline, @@ -23,36 +23,20 @@ create_bucket_sizes, ) from fairseq2.data.text import TextTokenizer, read_text -from fairseq2.datasets.batching import LengthBatching, StaticBatching +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.error import DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.error import DatasetError, SplitNotFoundError +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.typing import Device, override - - -@dataclass(unsafe_hash=True) # Due to FSDP, we cannot freeze. -class Direction: - """Represents the language direction of a parallel corpus.""" - - source_lang: str - """The source language code.""" - - target_lang: str - """The target language code.""" - - origin: Optional[str] = None - """The origin of data. Typically used to indicate mined or synthetic data.""" - - def __repr__(self) -> str: - s = f"{self.source_lang}-{self.target_lang}" - - if self.origin: - s = f"{self.origin}/{s}" - - return s +from fairseq2.typing import Device class ParallelTextDataset(ABC): @@ -64,21 +48,10 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - direction: Optional[Direction] = None, - min_seq_len: int = 1, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: ParallelTextReadOptions | None = None, ) -> DataReader[Seq2SeqBatch]: """Create a dataset reader. @@ -88,76 +61,78 @@ def create_reader( The tokenizer to encode text. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param direction: - The direction to read. If ``None``, all directions will be read. - :param min_seq_len: - The minimum sequence length of each example. Examples shorter than - this value will be dropped. - :param sample: - If ``True``, corpora will be sampled in proportion to their weights. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. + :param options: + The read options. """ @abstractmethod - def splits(self) -> Set[str]: + def splits(self) -> set[str]: """Return the set of splits.""" @abstractmethod - def directions(self, split: str) -> List[Direction]: + def directions(self, split: str) -> list[Direction]: """Return the directions included ``split``.""" -load_parallel_text_dataset = DelegatingDatasetLoader[ParallelTextDataset]() +@dataclass +class ParallelTextReadOptions(DataReadOptions): + direction: Direction | None = None + """The direction to read. If ``None``, all directions will be read.""" + + sample: bool = False + """If ``True``, corpora will be sampled in proportion to their weights.""" + + +@dataclass(unsafe_hash=True) # Due to FSDP, we cannot freeze. +class Direction: + """Represents the language direction of a parallel corpus.""" + + source_lang: str + """The source language code.""" + + target_lang: str + """The target language code.""" + + origin: str | None = None + """The origin of data. Typically used to indicate mined or synthetic data.""" + + def __repr__(self) -> str: + s = f"{self.source_lang}-{self.target_lang}" + + if self.origin: + s = f"{self.origin}/{s}" + + return s # TODO: FIX, INFER npc = 10 +GENERIC_PARALLEL_TEXT_DATASET_FAMILY: Final = "generic_parallel_text" + + @final class GenericParallelTextDataset(ParallelTextDataset): """Represents a generic file-based parallel text dataset.""" + _name: str _data_dir: Path - _splits: Dict[str, Tuple[List[Direction], List[float]]] + _splits: dict[str, tuple[list[Direction], list[float]]] def __init__( self, - *, + name: str, data_dir: Path, - splits: Dict[str, Tuple[List[Direction], List[float]]], + splits: dict[str, tuple[list[Direction], list[float]]], ) -> None: """ :param data_dir: @@ -166,6 +141,8 @@ def __init__( :param splits: The splits with their directions and their weights. """ + self._name = name + for split, (directions, weights) in splits.items(): if len(directions) != len(weights): raise ValueError( @@ -176,18 +153,24 @@ def __init__( self._splits = splits @classmethod - def from_path(cls, path: Path) -> GenericParallelTextDataset: - """Load a :class:`GenericParallelTextDataset` from ``path``.""" + def from_path( + cls, path: Path, name: str | None = None + ) -> GenericParallelTextDataset: + if name is None: + name = f"path:{path.name}" + path = path.expanduser().resolve() if not path.is_dir(): - raise ValueError("`path` must be a directory with a MANIFEST file.") + raise DatasetError( + name, f"The '{path}' path is expected to be a directory with a MANIFEST file." # fmt: skip + ) try: split_names = [d.name for d in path.iterdir() if d.is_dir()] except OSError as ex: - raise RuntimeError( - "The splits cannot be determined. See nested exception for details." + raise DatasetError( + name, f"The splits under the '{path}' directory cannot be determined. See the nested exception for details." # fmt: skip ) from ex splits = {} @@ -199,8 +182,8 @@ def from_path(cls, path: Path) -> GenericParallelTextDataset: with manifest_file.open() as fp: content = list(fp) except OSError as ex: - raise RuntimeError( - f"{manifest_file} cannot be read. See nested exception for details." + raise DatasetError( + name, f"The '{manifest_file}' file cannot be read. See the nested exception for details." # fmt: skip ) from ex # Sort the directions in alphabetical order. @@ -214,38 +197,38 @@ def from_path(cls, path: Path) -> GenericParallelTextDataset: # its weight (e.g. number of examples) in the split. for idx, line in enumerate(content): - def raise_error() -> NoReturn: - raise DatasetError( - f"Each line in {manifest_file} must represent a valid direction and a weight, but line {idx} is '{line}' instead." - ) from None + def error() -> DatasetError: + return DatasetError( + name, f"Each line in the '{manifest_file}' manifest file must represent a valid direction and a weight, but line {idx} is '{line}' instead." # fmt: skip + ) fields = line.rstrip().split("\t") if len(fields) != 2: - raise_error() + raise error() try: direction = cls._parse_direction(fields[0]) except ValueError: - raise_error() + raise error() from None directions.append(direction) try: weight = float(fields[1].strip()) except ValueError: - raise_error() + raise error() from None weights.append(weight) splits[split] = (directions, weights) - return GenericParallelTextDataset(data_dir=path, splits=splits) + return GenericParallelTextDataset(name, data_dir=path, splits=splits) @staticmethod def _parse_direction(s: str) -> Direction: - def raise_error() -> NoReturn: - raise ValueError( + def value_error() -> ValueError: + return ValueError( f"`s` must represent a valid direction, but is '{s}' instead." ) @@ -256,12 +239,12 @@ def raise_error() -> NoReturn: elif len(parts) == 2: origin, lang_pair = parts else: - raise_error() + raise value_error() parts = lang_pair.split("-") if len(parts) != 2: - raise_error() + raise value_error() source_lang, target_lang = parts @@ -269,7 +252,7 @@ def raise_error() -> NoReturn: target_lang = target_lang.strip() if not source_lang or not target_lang: - raise_error() + raise value_error() return Direction(source_lang, target_lang, origin) @@ -279,28 +262,25 @@ def create_reader( split: str, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - direction: Optional[Direction] = None, - min_seq_len: int = 1, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: ParallelTextReadOptions | None = None, ) -> DataPipelineReader[Seq2SeqBatch]: - try: - directions, weights = self._splits[split] - except KeyError: - self._raise_split_error(split) + directions_weights = self._splits.get(split) + if directions_weights is None: + raise SplitNotFoundError(self._name, split, self._splits.keys()) + + if options is None: + options = ParallelTextReadOptions() + + seed = options.seed + + directions, weights = directions_weights # Determine the directions to read. + direction = options.direction + if direction is not None: if direction not in directions: raise ValueError( @@ -339,7 +319,7 @@ def create_reader( pipelines.append(pipeline) - if sample: + if options.sample: builder = DataPipeline.sample(pipelines, weights=weights, seed=seed) seed += 1 @@ -347,8 +327,8 @@ def create_reader( builder = DataPipeline.concat(pipelines) # Shuffle examples. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(example_shuffle_window, seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed) seed += 1 @@ -358,7 +338,7 @@ def create_reader( seed += gang.rank # Encode source and target texts. - def encode(example: Dict[str, Any]) -> Dict[str, Any]: + def encode(example: dict[str, Any]) -> dict[str, Any]: direction = example["direction"] source_encoder, target_encoder = text_encoders[direction] @@ -372,8 +352,8 @@ def encode(example: Dict[str, Any]) -> Dict[str, Any]: if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( - max_seq_len=max_seq_len, min_seq_len=min_seq_len, + max_seq_len=max_seq_len, max_num_elements=batching.max_num_elements, ) @@ -385,11 +365,11 @@ def encode(example: Dict[str, Any]) -> Dict[str, Any]: min_data_len=min_seq_len, skip_below_min_examples=True, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) - else: + elif isinstance(batching, StaticBatching): # Filter out out-of-range examples. - def skip(example: Dict[str, Any]) -> bool: + def skip(example: dict[str, Any]) -> bool: source_len = len(example["source_indices"]) target_len = len(example["target_indices"]) @@ -400,11 +380,13 @@ def skip(example: Dict[str, Any]) -> bool: builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed) seed += 1 @@ -414,22 +396,24 @@ def skip(example: Dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) f = partial(self._to_batch, device=gang.device) pipeline = builder.map(f).and_return() return DataPipelineReader[Seq2SeqBatch]( + self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) def _read_direction(self, split: str, direction: Direction) -> DataPipelineBuilder: @@ -450,12 +434,12 @@ def _read_direction(self, split: str, direction: Direction) -> DataPipelineBuild if not source_file.exists(): raise DatasetError( - f"The source file '{source_file}' is not found under {self._data_dir}." + self._name, f"The source file '{source_file}' is not found under {self._data_dir}." # fmt: skip ) if not target_file.exists(): raise DatasetError( - f"The target file '{target_file}' is not found under {self._data_dir}." + self._name, f"The target file '{target_file}' is not found under {self._data_dir}." # fmt: skip ) source_builder = read_text(source_file, rtrim=True, memory_map=True) @@ -471,7 +455,7 @@ def _read_direction(self, split: str, direction: Direction) -> DataPipelineBuild ) @staticmethod - def _to_batch(example: Dict[str, Any], device: Device) -> Seq2SeqBatch: + def _to_batch(example: dict[str, Any], device: Device) -> Seq2SeqBatch: source_data = cast(SequenceData, example["source_indices"]) target_data = cast(SequenceData, example["target_indices"]) @@ -491,40 +475,19 @@ def _to_batch(example: Dict[str, Any], device: Device) -> Seq2SeqBatch: ) @override - def splits(self) -> Set[str]: + def splits(self) -> set[str]: return set(self._splits.keys()) @override - def directions(self, split: str) -> List[Direction]: - try: - directions, _ = self._splits[split] - except KeyError: - self._raise_split_error(split) + def directions(self, split: str) -> list[Direction]: + directions_weights = self._splits.get(split) + if directions_weights is None: + raise SplitNotFoundError(self._name, split, self._splits.keys()) - return directions + return directions_weights[0] - def _raise_split_error(self, split: str) -> NoReturn: - raise ValueError( - f"`split` must be one of the following splits, but is '{split}' instead: {', '.join(sorted(self._splits.keys()))}" - ) from None - -@final -class GenericParallelTextDatasetLoader( - AbstractDatasetLoader[GenericParallelTextDataset] -): - @override - def _load(self, path: Path, card: AssetCard) -> GenericParallelTextDataset: - try: - return GenericParallelTextDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_parallel_text_dataset = GenericParallelTextDatasetLoader() - -load_parallel_text_dataset.register( - "generic_parallel_text", load_generic_parallel_text_dataset -) +def load_parallel_text_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> ParallelTextDataset: + return load_dataset(name_or_card, ParallelTextDataset, force=force) diff --git a/src/fairseq2/datasets/preference.py b/src/fairseq2/datasets/preference.py index fd23d3f92..66e8794ba 100644 --- a/src/fairseq2/datasets/preference.py +++ b/src/fairseq2/datasets/preference.py @@ -1,15 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from __future__ import annotations import json from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Union, cast, final +from typing import Any, Final, cast, final import torch -from typing_extensions import NoReturn +from typing_extensions import override -from fairseq2.assets import AssetCard, AssetError +from fairseq2.assets import AssetCard from fairseq2.data import ( CollateOptionsOverride, Collater, @@ -20,45 +27,43 @@ read_sequence, ) from fairseq2.data.text import TextTokenizer -from fairseq2.datasets.batching import LengthBatching, StaticBatching +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) from fairseq2.datasets.data_reader import DataPipelineReader -from fairseq2.datasets.error import DatasetError -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.static import load_dataset +from fairseq2.datasets.utils import _load_files_and_weights +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.sequence import SequenceBatch from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.typing import override @dataclass class PreferenceOptimizationBatch: - """Represents a preference optimization batch.""" + """Represents a preference optimization dataset batch.""" chosen: SequenceBatch rejected: SequenceBatch + reference_score_chosen: torch.Tensor | None + reference_score_rejected: torch.Tensor | None class PreferenceOptimizationDataset(ABC): - """Represents an preference optimization finetuning dataset.""" + """Represents a preference optimization dataset.""" @abstractmethod def create_reader( self, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: PreferenceReadOptions | None = None, ) -> DataPipelineReader[PreferenceOptimizationBatch]: """Create a dataset reader. @@ -66,67 +71,68 @@ def create_reader( The tokenizer to encode text. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param sample: - If ``True``, instruction sources (e.g. files) will be sampled in - proportion to their weights. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. + :param options: + The read options. """ -load_preference_optimization_dataset = DelegatingDatasetLoader[ - PreferenceOptimizationDataset -]() +@dataclass +class PreferenceReadOptions(DataReadOptions): + sample: bool = False + """ + If ``True``, instruction sources (e.g. JSONL files) will be sampled in + proportion to their weights. + """ + + mask_source_tokens: bool = True + """ + If ``False``, calculates loss on the source tokens (prompt) as well as the + target tokens. + """ + + source_encode_mode: str = "prompt" + """The tokenizer mode to encode the source text.""" + + target_encode_mode: str = "prompt_response" + """The tokenizer mode to encode the target text.""" + # TODO: FIX, INFER npc = 10 +GENERIC_PREFERENCE_OPTIMIZATION_DATASET_FAMILY: Final = ( + "generic_preference_optimization" +) + + @final class GenericPreferenceOptimizationDataset(PreferenceOptimizationDataset): - """Represents a generic JSONL preferemce preference optimization dataset.""" + """Represents a generic JSONL preference optimization dataset.""" + _name: str _files: Sequence[Path] _weights: Sequence[float] - def __init__(self, files: Sequence[Path], weights: Sequence[float]) -> None: + def __init__( + self, name: str, files: Sequence[Path], weights: Sequence[float] + ) -> None: """ :param files: The instruction files. :param weights: The weight of each file in ``files``. """ + self._name = name + if len(files) != len(weights): raise ValueError( f"The lengths of `files` and `weights` must match, but they are {len(files)} and {len(weights)} instead." @@ -135,113 +141,32 @@ def __init__(self, files: Sequence[Path], weights: Sequence[float]) -> None: self._files = files self._weights = weights - @classmethod - def from_path(cls, path: Path) -> GenericPreferenceOptimizationDataset: - """Load a :class:`PreferenceOptimizationDataset` from ``path``.""" - path = path.expanduser().resolve() - - if not path.is_dir(): - return GenericPreferenceOptimizationDataset(files=[path], weights=[1.0]) - - manifest_file = path.joinpath("MANIFEST") - - try: - fp = manifest_file.open() - except FileNotFoundError: - fp = None - except OSError as ex: - raise RuntimeError( - f"{manifest_file} cannot be read. See nested exception for details." - ) from ex - - # If the directory does not contain a MANIFEST file, treat all JSONL - # files as part of the dataset with equal weight. - if fp is None: - try: - files = list(path.glob("**/*.jsonl")) - except OSError as ex: - raise RuntimeError( - f"The JSONL files under {path} cannot be retrieved. See nested exception for details." - ) from ex - - weights = [1.0 for _ in range(len(files))] - - return GenericPreferenceOptimizationDataset(files, weights=weights) - - try: - content = list(fp) - except OSError as ex: - raise RuntimeError( - f"{manifest_file} cannot be read. See nested exception for details." - ) from ex - finally: - fp.close() - - # Sort the JSONL files in alphabetical order. - content.sort() - - files = [] - - weights = [] - - # Each line of the MANIFEST file corresponds to the path of a JSONL file - # and its weight (e.g. number of examples). - for idx, line in enumerate(content): - - def raise_error() -> NoReturn: - raise DatasetError( - f"Each line in {manifest_file} must represent a path to a JSONL file and a weight, but line {idx} is '{line}' instead." - ) - - fields = line.rstrip().split("\t") - - if len(fields) != 2: - raise_error() - - file_path = fields[0].strip() - if not file_path: - raise_error() - - try: - file = path.joinpath(file_path) - except ValueError: - raise_error() - - if not file.exists(): - raise DatasetError( - f"The file '{file}' referred at line {idx} in {manifest_file} does not exist." - ) - - files.append(file) - - try: - weight = float(fields[1].strip()) - except ValueError: - raise_error() + @staticmethod + def from_path( + path: Path, name: str | None = None + ) -> GenericPreferenceOptimizationDataset: + if name is None: + name = f"path:{path.name}" - weights.append(weight) + files, weights = _load_files_and_weights(name, path) - return GenericPreferenceOptimizationDataset(files, weights) + return GenericPreferenceOptimizationDataset(name, files, weights) @override def create_reader( self, tokenizer: TextTokenizer, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - sample: bool = False, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: PreferenceReadOptions | None = None, ) -> DataPipelineReader[PreferenceOptimizationBatch]: + if options is None: + options = PreferenceReadOptions() + + seed = options.seed + if len(self._files) == 1: builder = self._read_jsonl(self._files[0], tokenizer) else: @@ -252,7 +177,7 @@ def create_reader( pipelines.append(pipeline) - if sample: + if options.sample: builder = DataPipeline.sample( pipelines, weights=self._weights, seed=seed ) @@ -262,79 +187,100 @@ def create_reader( builder = DataPipeline.concat(pipelines) # Shuffle files. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(shuffle_window=0, seed=seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed=seed) seed += 1 + # Shard. + builder.shard(gang.rank, gang.size, allow_uneven=True) + seed += gang.rank - # Encode prompt and target texts. - prompt_encoder = tokenizer.create_encoder(mode="prompt") - target_encoder = tokenizer.create_encoder(mode="prompt_response") + # Encode source and target texts. + source_encoder = tokenizer.create_encoder(mode=options.source_encode_mode) + target_encoder = tokenizer.create_encoder(mode=options.target_encode_mode) - builder.map(prompt_encoder, selector="src", num_parallel_calls=npc) + builder.map(source_encoder, selector="src", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt_chosen", num_parallel_calls=npc) builder.map(target_encoder, selector="tgt_rejected", num_parallel_calls=npc) - # Filter out long examples. - def skip(example: Dict[str, Any]) -> bool: - chosen_len = len(example["src"]) + len(example["tgt_chosen"]) - rejected_len = len(example["src"]) + len(example["tgt_rejected"]) - return chosen_len <= max_seq_len and rejected_len <= max_seq_len + def cat_source_and_target(example: dict[str, Any]) -> dict[str, Any]: + id_ = example.get("id", None) - builder.filter(skip) - - static_batching = isinstance(batching, StaticBatching) - - # Shard. - builder.shard(gang.rank, gang.size, allow_uneven=not static_batching) - - def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]: - sample_id = example.get("id", None) - - prompt_indices = example["src"] + source_indices = example["src"] target_indices_chosen = example["tgt_chosen"] target_indices_rejected = example["tgt_rejected"] - indices_chosen = torch.cat([prompt_indices, target_indices_chosen]) - indices_rejected = torch.cat([prompt_indices, target_indices_rejected]) + indices_chosen = torch.cat([source_indices, target_indices_chosen]) + indices_rejected = torch.cat([source_indices, target_indices_rejected]) - target_mask_chosen = torch.arange(len(indices_chosen)) >= len( - prompt_indices - ) - target_mask_rejected = torch.arange(len(indices_rejected)) >= len( - prompt_indices + if options.mask_source_tokens: + source_len = len(source_indices) + target_mask_chosen = torch.arange(len(indices_chosen)) >= source_len + target_mask_rejected = torch.arange(len(indices_rejected)) >= source_len + else: + target_mask_chosen = torch.full([len(indices_chosen)], True) + target_mask_rejected = torch.full([len(indices_rejected)], True) + + total_tokens = ( + 2 * len(source_indices) + + len(target_indices_chosen) + + len(target_indices_rejected) ) return { + "id": id_, + "indices_prompt": source_indices, "indices_chosen": indices_chosen, "indices_rejected": indices_rejected, + "reference_score_chosen": example.get("reference_score_chosen", None), + "reference_score_rejected": example.get( + "reference_score_rejected", None + ), "target_mask_chosen": target_mask_chosen, "target_mask_rejected": target_mask_rejected, - "id": sample_id, + "total_tokens": total_tokens, } builder.map(cat_source_and_target, num_parallel_calls=npc) if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( - max_seq_len=max_seq_len, max_num_elements=batching.max_num_elements + min_seq_len=min_seq_len, + max_seq_len=max_seq_len, + max_num_elements=batching.max_num_elements, ) # Bucket by the sequence length. builder.bucket_by_length( bucket_sizes, - selector="indices_chosen,indices_rejected", + selector="total_tokens", + min_data_len=min_seq_len, skip_above_max_examples=True, + drop_remainder=options.drop_remainder, ) - else: + elif isinstance(batching, StaticBatching): + # Filter out long examples. + def skip(example: dict[str, Any]) -> bool: + chosen_len = len(example["indices_chosen"]) + rejected_len = len(example["indices_rejected"]) + + if chosen_len > max_seq_len or rejected_len > max_seq_len: + return False + + return chosen_len >= min_seq_len and rejected_len >= min_seq_len + + builder.filter(skip) + # Bucket `batch_size` examples. - builder.bucket(batching.batch_size) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed=seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed=seed) seed += 1 @@ -349,14 +295,14 @@ def cat_source_and_target(example: Dict[str, Any]) -> Dict[str, Any]: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) - # Wrap examples with `SequenceBatch`. - def to_batch(example: Dict[str, Any]) -> PreferenceOptimizationBatch: + # Wrap examples with `PreferenceOptimizationBatch`. + def to_batch(example: dict[str, Any]) -> PreferenceOptimizationBatch: indices_chosen = cast(SequenceData, example["indices_chosen"]) indices_rejected = cast(SequenceData, example["indices_rejected"]) @@ -368,13 +314,15 @@ def to_batch(example: Dict[str, Any]) -> PreferenceOptimizationBatch: ) target_mask_chosen = example["target_mask_chosen"]["seqs"].to(gang.device) - target_mask_rejected = example["target_mask_rejected"]["seqs"].to( - gang.device - ) + target_mask_rejected = example["target_mask_rejected"]["seqs"].to(gang.device) # fmt: skip batch_chosen = SequenceBatch( - seqs_chosen, padding_mask_chosen, target_mask_chosen, example=example + seqs_chosen, + padding_mask_chosen, + target_mask_chosen, + example=example, ) + batch_rejected = SequenceBatch( seqs_rejected, padding_mask_rejected, @@ -382,18 +330,33 @@ def to_batch(example: Dict[str, Any]) -> PreferenceOptimizationBatch: example=example, ) + batch_reference_scores_chosen = None + if all(example["reference_score_chosen"]): + batch_reference_scores_chosen = torch.Tensor( + example["reference_score_chosen"] + ).to(gang.device) + batch_reference_scores_rejected = None + if all(example["reference_score_rejected"]): + batch_reference_scores_rejected = torch.Tensor( + example["reference_score_rejected"] + ).to(gang.device) + return PreferenceOptimizationBatch( - chosen=batch_chosen, rejected=batch_rejected + batch_chosen, + batch_rejected, + batch_reference_scores_chosen, + batch_reference_scores_rejected, ) pipeline = builder.map(to_batch).and_return() return DataPipelineReader[PreferenceOptimizationBatch]( + self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, ) def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder: @@ -407,26 +370,7 @@ def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuild return read_sequence(lines).map(json.loads, num_parallel_calls=npc) -@final -class GenericPreferenceOptimizationDatasetLoader( - AbstractDatasetLoader[GenericPreferenceOptimizationDataset] -): - @override - def _load( - self, path: Path, card: AssetCard - ) -> GenericPreferenceOptimizationDataset: - try: - return GenericPreferenceOptimizationDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_preference_optimization_dataset = ( - GenericPreferenceOptimizationDatasetLoader() -) - -load_preference_optimization_dataset.register( - "generic_preference_optimization", load_generic_preference_optimization_dataset -) +def load_preference_optimization_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> PreferenceOptimizationDataset: + return load_dataset(name_or_card, PreferenceOptimizationDataset, force=force) diff --git a/src/fairseq2/datasets/speech.py b/src/fairseq2/datasets/speech.py new file mode 100644 index 000000000..66e891b09 --- /dev/null +++ b/src/fairseq2/datasets/speech.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Final, final + +import torch +from typing_extensions import override + +from fairseq2.assets import AssetCard +from fairseq2.datasets.config import Batching, DataReadOptions +from fairseq2.datasets.data_reader import DataPipelineReader, DataReader +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError +from fairseq2.gang import Gang +from fairseq2.models.sequence import SequenceBatch +from fairseq2.typing import DataType + + +class SpeechDataset(ABC): + """Represents a speech dataset.""" + + @abstractmethod + def create_reader( + self, + split: str, + gang: Gang, + min_audio_len: int, + max_audio_len: int, + batching: Batching, + options: SpeechReadOptions | None = None, + ) -> DataReader[SequenceBatch]: + """Create a dataset reader. + + :param split: + The split to read. + :param gang: + The gang over which to shard the dataset. + :param min_audio_len: + The minimum audio length of each example. Examples shorter than this + value will be dropped. + :param max_audio_len: + The maximum audio length of each example. Examples longer than this + value will be dropped. + :param batching: + The batching strategy for returned examples. + :param options: + The read options. + """ + + @abstractmethod + def splits(self) -> set[str]: + """Return the set of splits.""" + + +@dataclass +class SpeechReadOptions(DataReadOptions): + dtype: DataType = torch.float32 + """The data type of the decoded audio sequences.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" + + +GENERIC_SPEECH_DATASET_FAMILY: Final = "generic_speech" + + +@final +class GenericSpeechDataset(SpeechDataset): + """Represents a generic manifest-based Speech dataset.""" + + @staticmethod + def from_path(path: Path, name: str | None = None) -> GenericSpeechDataset: + return GenericSpeechDataset() + + @override + def create_reader( + self, + split: str, + gang: Gang, + min_audio_len: int, + max_audio_len: int, + batching: Batching, + options: SpeechReadOptions | None = None, + ) -> DataPipelineReader[SequenceBatch]: + raise NotSupportedError("not supported yet.") + + @override + def splits(self) -> set[str]: + return set() + + +def load_speech_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> SpeechDataset: + return load_dataset(name_or_card, SpeechDataset, force=force) diff --git a/src/fairseq2/datasets/static.py b/src/fairseq2/datasets/static.py new file mode 100644 index 000000000..392f03d44 --- /dev/null +++ b/src/fairseq2/datasets/static.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import TypeVar + +from fairseq2.assets import AssetCard +from fairseq2.context import get_runtime_context +from fairseq2.datasets.handler import ( + DatasetHandler, + DatasetNotFoundError, + get_dataset_family, +) +from fairseq2.error import ContractError + +DatasetT = TypeVar("DatasetT") + + +def load_dataset( + name_or_card: str | AssetCard, kls: type[DatasetT], *, force: bool = False +) -> DatasetT: + context = get_runtime_context() + + if isinstance(name_or_card, AssetCard): + card = name_or_card + else: + card = context.asset_store.retrieve_card(name_or_card) + + family = get_dataset_family(card) + + registry = context.get_registry(DatasetHandler) + + try: + handler = registry.get(family) + except LookupError: + raise DatasetNotFoundError(card.name) from None + + if not issubclass(handler.kls, kls): + raise TypeError( + f"The dataset is expected to be of type `{kls}`, but is of type `{type(handler.kls)}` instead." + ) + + dataset = handler.load(card, force=force) + + if not isinstance(dataset, kls): + raise ContractError( + f"The dataset is expected to be of type `{kls}`, but is of type `{type(handler.kls)}` instead." + ) + + return dataset diff --git a/src/fairseq2/datasets/text.py b/src/fairseq2/datasets/text.py index 566f801ad..4a5fb8457 100644 --- a/src/fairseq2/datasets/text.py +++ b/src/fairseq2/datasets/text.py @@ -7,11 +7,15 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Union, cast, final +from typing import Any, Final, cast, final -from fairseq2.assets import AssetCard, AssetError +from typing_extensions import override + +from fairseq2.assets import AssetCard from fairseq2.data import ( Collater, DataPipeline, @@ -20,13 +24,20 @@ read_sequence, ) from fairseq2.data.text import TextTokenEncoder, read_text -from fairseq2.datasets.batching import LengthBatching, StaticBatching +from fairseq2.datasets.config import ( + Batching, + DataReadOptions, + LengthBatching, + StaticBatching, +) from fairseq2.datasets.data_reader import DataPipelineReader, DataReader -from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader +from fairseq2.datasets.error import DatasetError +from fairseq2.datasets.static import load_dataset +from fairseq2.error import NotSupportedError from fairseq2.gang import Gang from fairseq2.models.sequence import SequenceBatch from fairseq2.nn.padding import get_seqs_and_padding_mask -from fairseq2.typing import Device, override +from fairseq2.typing import Device class TextDataset(ABC): @@ -36,21 +47,12 @@ class TextDataset(ABC): def create_reader( self, text_encoder: TextTokenEncoder, - pad_idx: Optional[int], + pad_idx: int | None, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - min_seq_len: int = 1, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: TextReadOptions | None = None, ) -> DataReader[SequenceBatch]: """Create a dataset reader. @@ -60,68 +62,51 @@ def create_reader( The index of the PAD symbol in the vocabulary of ``text_encoder``. :param gang: The gang over which to shard the dataset. + :param min_seq_len: + The minimum sequence length of each example. Examples shorter than + this value will be dropped. :param max_seq_len: The maximum sequence length of each example. Examples longer than this value will be dropped. :param batching: The batching strategy for returned examples. - :param min_seq_len: - The minimum sequence length of each example. Examples shorter than - this value will be dropped. - :param example_shuffle_window: - The size of the sliding window for shuffling examples. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param batch_shuffle_window: - The size of the sliding window for shuffling batches. If ``1``, no - shuffling is performed; if ``0``, true shuffling is performed by - loading the entire dataset. - :param drop_remainder: - If ``True``, drops the last set of batches if they have in total - fewer examples than requested. - :param sync_batches: - If ``True``, ensures that each process in ``gang`` reads the same - number of batches. Typically used when the amount of data to be read - can vary per process (e.g. due to unbalanced sharding or non-static - batching) and it is critical for each process to iterate over the - same number of batches (e.g. during training). - :param max_num_batches: - The maximum number of batches to return. - :param num_accumulate: - The number of batches to accumulate in each iteration. Typically - used with gradient accumulation during training. - :param num_prefetch: - The number of batches to prefetch in background. - :param seed: - The seed to initialize the random number generators used internally. - :param extras: - The extra parameters specific to the dataset implementation. + :param options: + The read options. """ -load_text_dataset = DelegatingDatasetLoader[TextDataset]() +@dataclass +class TextReadOptions(DataReadOptions): + pass # TODO: FIX, INFER npc = 10 +GENERIC_TEXT_DATASET_FAMILY: Final = "generic_text" + + @final class GenericTextDataset(TextDataset): """Represents a generic file-based text dataset.""" + _name: str _files: Sequence[Path] - def __init__(self, files: Sequence[Path]) -> None: + def __init__(self, name: str, files: Sequence[Path]) -> None: """ :param data_dir: The list of text files that represent the dataset. """ + self._name = name self._files = files @staticmethod - def from_path(path: Path) -> GenericTextDataset: - """Load a :class:`GenericTextDataset` from ``path``.""" + def from_path(path: Path, name: str | None = None) -> GenericTextDataset: + if name is None: + name = f"path:{path.name}" + path = path.expanduser().resolve() if not path.is_dir(): @@ -130,34 +115,30 @@ def from_path(path: Path) -> GenericTextDataset: try: files = [f for f in path.glob("**/*.txt") if not f.is_dir()] except OSError as ex: - raise RuntimeError( - f"The text files under {path} cannot be retrieved. See nested exception for details." + raise DatasetError( + name, f"The text files under the '{path}' directory cannot be retrieved. See the nested exception for details." # fmt: skip ) from ex files.sort() - return GenericTextDataset(files) + return GenericTextDataset(name, files) @override def create_reader( self, text_encoder: TextTokenEncoder, - pad_idx: Optional[int], + pad_idx: int | None, gang: Gang, + min_seq_len: int, max_seq_len: int, - batching: Union[StaticBatching, LengthBatching], - *, - min_seq_len: int = 1, - example_shuffle_window: int = 1, - batch_shuffle_window: int = 1, - drop_remainder: bool = False, - sync_batches: bool = True, - max_num_batches: Optional[int] = None, - num_accumulate: int = 1, - num_prefetch: int = 1, - seed: int = 2, - **extras: Any, + batching: Batching, + options: TextReadOptions | None = None, ) -> DataReader[SequenceBatch]: + if options is None: + options = TextReadOptions() + + seed = options.seed + if len(self._files) == 1: builder = read_text(self._files[0], key="text", rtrim=True) else: @@ -169,8 +150,8 @@ def read_file(file: Path) -> DataPipeline: builder.yield_from(read_file) # Shuffle examples. Must be consistent across all processes. - if example_shuffle_window != 1: - builder.shuffle(example_shuffle_window, seed) + if options.example_shuffle_window != 1: + builder.shuffle(options.example_shuffle_window, seed) seed += 1 @@ -179,7 +160,7 @@ def read_file(file: Path) -> DataPipeline: seed += gang.rank - def encode(example: Dict[str, Any]) -> Dict[str, Any]: + def encode(example: dict[str, Any]) -> dict[str, Any]: example["indices"] = text_encoder(example["text"]) return example @@ -200,11 +181,11 @@ def encode(example: Dict[str, Any]) -> Dict[str, Any]: min_data_len=min_seq_len, skip_below_min_examples=True, skip_above_max_examples=True, - drop_remainder=drop_remainder, + drop_remainder=options.drop_remainder, ) - else: + elif isinstance(batching, StaticBatching): # Filter out out-of-range examples. - def skip(example: Dict[str, Any]) -> bool: + def skip(example: dict[str, Any]) -> bool: seq_len = len(example["indices"]) return seq_len >= min_seq_len and seq_len <= max_seq_len @@ -212,11 +193,13 @@ def skip(example: Dict[str, Any]) -> bool: builder.filter(skip) # Bucket `batch_size` examples. - builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + builder.bucket(batching.batch_size, drop_remainder=options.drop_remainder) + else: + raise NotSupportedError(f"`{batching}` is not supported.") # Shuffle buckets. - if batch_shuffle_window != 1: - builder.shuffle(batch_shuffle_window, seed) + if options.batch_shuffle_window != 1: + builder.shuffle(options.batch_shuffle_window, seed) seed += 1 @@ -226,26 +209,28 @@ def skip(example: Dict[str, Any]) -> bool: builder.map(collater, num_parallel_calls=npc) # Return only the first `max_num_batches`. - if max_num_batches is not None: - builder.take(max_num_batches) + if options.max_num_batches is not None: + builder.take(options.max_num_batches) # Prefetch `num_prefetch` batches in background. - builder.prefetch(num_prefetch) + builder.prefetch(options.num_prefetch) f = partial(self._to_batch, device=gang.device) pipeline = builder.map(f).and_return() return DataPipelineReader[SequenceBatch]( + self._name, pipeline, gang, - num_accumulate=num_accumulate, - drop_remainder=drop_remainder, - sync_batches=sync_batches, + num_accumulate=options.num_accumulate, + drop_remainder=options.drop_remainder, + sync_batches=options.sync_batches, + sync_mode=options.sync_mode, ) @staticmethod - def _to_batch(example: Dict[str, Any], device: Device) -> SequenceBatch: + def _to_batch(example: dict[str, Any], device: Device) -> SequenceBatch: data = cast(SequenceData, example["indices"]) seqs, padding_mask = get_seqs_and_padding_mask(data, device) @@ -253,18 +238,7 @@ def _to_batch(example: Dict[str, Any], device: Device) -> SequenceBatch: return SequenceBatch(seqs, padding_mask, example=example) -@final -class GenericTextDatasetLoader(AbstractDatasetLoader[GenericTextDataset]): - @override - def _load(self, path: Path, card: AssetCard) -> GenericTextDataset: - try: - return GenericTextDataset.from_path(path) - except RuntimeError as ex: - raise AssetError( - f"{card.name} cannot be loaded. See nested exception for details." - ) from ex - - -load_generic_text_dataset = GenericTextDatasetLoader() - -load_text_dataset.register("generic_text", load_generic_text_dataset) +def load_text_dataset( + name_or_card: str | AssetCard, *, force: bool = False +) -> TextDataset: + return load_dataset(name_or_card, TextDataset, force=force) diff --git a/src/fairseq2/datasets/utils.py b/src/fairseq2/datasets/utils.py index 7d8e6d356..8f2a0b526 100644 --- a/src/fairseq2/datasets/utils.py +++ b/src/fairseq2/datasets/utils.py @@ -6,18 +6,21 @@ from __future__ import annotations +from pathlib import Path + import torch -from fairseq2.gang import Gang -from fairseq2.logging import LogWriter +from fairseq2.datasets.error import DatasetError +from fairseq2.gang import Gang, all_sum +from fairseq2.logging import log -def _reduce_num_batches(num_batches: int, gang: Gang, log: LogWriter) -> int: +def _min_num_batches(num_batches: int, gang: Gang) -> int: all_num_batches = torch.zeros((gang.size,), device=gang.device, dtype=torch.int64) - num_batches_ = torch.tensor(num_batches, device=gang.device) + inp = torch.tensor(num_batches, device=gang.device) - gang.all_gather(all_num_batches, num_batches_) + gang.all_gather(all_num_batches, inp) min_num_batches = int(all_num_batches.min()) if min_num_batches != 0: @@ -33,3 +36,88 @@ def _reduce_num_batches(num_batches: int, gang: Gang, log: LogWriter) -> int: log.debug("End of data reached at rank(s) {}.", s) return 0 + + +def _sum_num_batches(num_batches: int, gang: Gang) -> int: + total_num_batches = all_sum(gang, num_batches) + + return int(total_num_batches) + + +def _load_files_and_weights(name: str, path: Path) -> tuple[list[Path], list[float]]: + path = path.expanduser().resolve() + + if not path.is_dir(): + return [path], [1.0] + + manifest_file = path.joinpath("MANIFEST") + + try: + with manifest_file.open() as fp: + content = list(fp) + except FileNotFoundError: + content = None + except OSError as ex: + raise DatasetError( + name, f"The '{manifest_file}' manifest file cannot be read. See the nested exception for details." # fmt: skip + ) from ex + + # If the directory does not contain a MANIFEST file, treat all JSONL + # files as part of the dataset with equal weight. + if content is None: + try: + files = list(path.glob("**/*.jsonl")) + except OSError as ex: + raise DatasetError( + name, f"The JSONL files under the '{path}' directory cannot be retrieved. See the nested exception for details." # fmt: skip + ) from ex + + weights = [1.0 for _ in range(len(files))] + + return files, weights + + # Sort the JSONL files in alphabetical order. + content.sort() + + files = [] + + weights = [] + + # Each line of the MANIFEST file corresponds to the path of a JSONL file + # and its weight (e.g. number of examples). + for idx, line in enumerate(content): + + def error() -> DatasetError: + return DatasetError( + name, f"Each line in the '{manifest_file}' manifest file must represent a path to a JSONL file and a weight, but line {idx} is '{line}' instead." # fmt: skip + ) + + fields = line.rstrip().split("\t") + + if len(fields) != 2: + raise error() + + file_path = fields[0].strip() + if not file_path: + raise error() + + try: + file = path.joinpath(file_path) + except ValueError: + raise error() from None + + if not file.exists(): + raise DatasetError( + name, f"The '{file}' file referred at line {idx} in the '{manifest_file}' manifest file does not exist." # fmt: skip + ) + + files.append(file) + + try: + weight = float(fields[1].strip()) + except ValueError: + raise error() from None + + weights.append(weight) + + return files, weights diff --git a/src/fairseq2/device.py b/src/fairseq2/device.py index 546408c5b..46b8a2e41 100644 --- a/src/fairseq2/device.py +++ b/src/fairseq2/device.py @@ -7,18 +7,19 @@ from __future__ import annotations import os -from typing import Optional import torch -from fairseq2.logging import get_log_writer +from fairseq2.error import InternalError +from fairseq2.logging import log from fairseq2.typing import CPU, Device -from fairseq2.utils.env import get_int_from_env +from fairseq2.utils.env import ( + InvalidEnvironmentVariableError, + get_device_from_env, + get_int_from_env, +) -log = get_log_writer(__name__) - - -_default_device: Optional[Device] = None +_default_device: Device | None = None def determine_default_device() -> Device: @@ -28,14 +29,12 @@ def determine_default_device() -> Device: if _default_device is not None: return _default_device - device_str = os.environ.get("FAIRSEQ2_DEVICE") - if device_str is not None: - try: - _default_device = Device(device_str) - except RuntimeError as ex: - raise RuntimeError( - f"The value of the `FAIRSEQ2_DEVICE` environment variable must specify a valid PyTorch device, but is '{device_str}' instead." - ) from ex + try: + _default_device = get_device_from_env("FAIRSEQ2_DEVICE") + except InvalidEnvironmentVariableError as ex: + raise DeviceDetectionError( + "The default device cannot be set using the `FAIRSEQ2_DEVICE` environment variable. See the nested exception for details." + ) from ex if _default_device is None: _default_device = determine_default_cuda_device() @@ -51,7 +50,7 @@ def determine_default_device() -> Device: return _default_device -def determine_default_cuda_device() -> Optional[Device]: +def determine_default_cuda_device() -> Device | None: """Determine the default CUDA ``torch.device`` of the process.""" if not torch.cuda.is_available(): return None @@ -74,7 +73,12 @@ def determine_default_cuda_device() -> Optional[Device]: device = None if device is None: - idx = _get_device_index(num_devices, device_type="cuda") + try: + idx = _get_device_index(num_devices, device_type="cuda") + except InvalidEnvironmentVariableError as ex: + raise DeviceDetectionError( + "The default `cuda` device index cannot be inferred from the environment. See the nested exception for details." + ) from ex device = Device("cuda", index=idx) @@ -82,7 +86,8 @@ def determine_default_cuda_device() -> Optional[Device]: def _get_device_index(num_devices: int, device_type: str) -> int: - assert num_devices > 0 + if num_devices <= 0: + raise InternalError(f"`num_devices` is {num_devices}.") # We use the `LOCAL_RANK` environment variable to determine which device to # pick in case the process has more than one available. @@ -90,20 +95,24 @@ def _get_device_index(num_devices: int, device_type: str) -> int: if device_idx is None: num_procs = get_int_from_env("LOCAL_WORLD_SIZE") if num_procs is not None and num_procs > 1 and num_devices > 1: - raise RuntimeError( + raise InvalidEnvironmentVariableError( f"The default `{device_type}` device cannot be determined. There are {num_devices} devices available, but the `LOCAL_RANK` environment variable is not set." ) return 0 if device_idx < 0: - raise RuntimeError( - f"The value of the `LOCAL_RANK` environment variable must be greater than or equal to 0, but is {device_idx} instead." + raise InvalidEnvironmentVariableError( + f"The value of the `LOCAL_RANK` environment variable is expected to be greater than or equal to 0, but is {device_idx} instead." ) if device_idx >= num_devices: - raise RuntimeError( - f"The value of the `LOCAL_RANK` environment variable must be less than the number of available `{device_type}` devices ({num_devices}), but is {device_idx} instead." + raise InvalidEnvironmentVariableError( + f"The value of the `LOCAL_RANK` environment variable is expected to be less than the number of available `{device_type}` devices ({num_devices}), but is {device_idx} instead." ) return device_idx + + +class DeviceDetectionError(Exception): + pass diff --git a/src/fairseq2/early_stopper.py b/src/fairseq2/early_stopper.py deleted file mode 100644 index 41f3e39c9..000000000 --- a/src/fairseq2/early_stopper.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from typing import Protocol - - -class EarlyStopper(Protocol): - """Stops training when an implementation-specific condition is not met.""" - - def __call__(self, step_nr: int, score: float) -> bool: - """ - :param step_nr: - The number of the current training step. - :para score: - The validation score of the current training step. - - :returns: - ``True`` if the training should be stopped; otherwise, ``False``. - """ diff --git a/src/fairseq2/error.py b/src/fairseq2/error.py new file mode 100644 index 000000000..4b7accb1d --- /dev/null +++ b/src/fairseq2/error.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + + +class InternalError(Exception): + pass + + +class ContractError(Exception): + pass + + +class InvalidOperationError(Exception): + pass + + +class AlreadyExistsError(Exception): + pass + + +class NotSupportedError(Exception): + pass + + +class SetupError(Exception): + pass diff --git a/src/fairseq2/extensions.py b/src/fairseq2/extensions.py new file mode 100644 index 000000000..9b1683d3b --- /dev/null +++ b/src/fairseq2/extensions.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +from typing import Any + +from importlib_metadata import entry_points + +from fairseq2.logging import log + + +def run_extensions(extension_name: str, *args: Any, **kwargs: Any) -> None: + should_trace = "FAIRSEQ2_EXTENSION_TRACE" in os.environ + + for entry_point in entry_points(group=extension_name): + try: + extension = entry_point.load() + + extension(*args, **kwargs) + except TypeError: + if should_trace: + raise ExtensionError( + entry_point.value, f"The '{entry_point.value}' entry point is not a valid extension function." # fmt: skip + ) from None + + log.warning("The '{}' entry point is not a valid extension function. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip + except Exception as ex: + if should_trace: + raise ExtensionError( + entry_point.value, f"The '{entry_point.value}' extension function has failed. See the nested exception for details." # fmt: skip + ) from ex + + log.warning("The '{}' extension function has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip + + if should_trace: + log.info("The `{}` extension function run successfully.", entry_point.value) # fmt: skip + + +class ExtensionError(Exception): + _entry_point: str + + def __init__(self, entry_point: str, message: str) -> None: + super().__init__(message) + + self._entry_point = entry_point + + @property + def entry_point(self) -> str: + return self._entry_point diff --git a/src/fairseq2/factory_registry.py b/src/fairseq2/factory_registry.py new file mode 100644 index 000000000..da902eb97 --- /dev/null +++ b/src/fairseq2/factory_registry.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import is_dataclass +from functools import partial +from inspect import isfunction +from typing import ( + Any, + Generic, + ParamSpec, + Protocol, + TypeVar, + cast, + final, + get_type_hints, +) + +from fairseq2.config_registry import ConfigRegistry +from fairseq2.typing import DataClass +from fairseq2.utils.dataclass import merge_dataclass +from fairseq2.utils.structured import ValueConverter, default_value_converter + +ConfigT = TypeVar("ConfigT", bound=DataClass) + +ConfigT_contra = TypeVar("ConfigT_contra", bound=DataClass, contravariant=True) + +P = ParamSpec("P") + +R = TypeVar("R") + +R_co = TypeVar("R_co", covariant=True) + + +class Factory(Protocol[ConfigT_contra, P, R_co]): + def __call__( + self, config: ConfigT_contra, *args: P.args, **kwargs: P.kwargs + ) -> R_co: + ... + + +class ConfigBoundFactory(Protocol[P, R_co]): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co: + ... + + @property + def config(self) -> DataClass: + ... + + +@final +class ConfigBoundFactoryRegistry(Generic[P, R]): + """Holds factories with parameter(s) ``P`` and return type ``R``.""" + + _factories: dict[ + str, tuple[Callable[..., R], type[DataClass], ConfigRegistry[Any] | None] + ] + + _value_converter: ValueConverter + + def __init__(self, value_converter: ValueConverter | None = None) -> None: + self._factories = {} + + self._value_converter = value_converter or default_value_converter + + def get( + self, + name: str, + unstructured_config: object = None, + base_config_name: str | None = None, + *, + set_empty: bool = False, + ) -> ConfigBoundFactory[P, R]: + """Return the factory with ``name``. + + :param config: + The configuration to bind to the factory. + :param base_config_name: + The name of the configuration on which ``config`` will be based. + """ + try: + factory, config_kls, config_registry = self._factories[name] + except KeyError: + raise ValueError( + f"`name` must be a registered name, but '{name}' is not registered." + ) from None + + if unstructured_config is None: + config = None + else: + config = self._value_converter.structure( + unstructured_config, config_kls, set_empty=set_empty + ) + + if base_config_name is None: + if config is None: + try: + config = config_kls() + except TypeError as ex: + raise RuntimeError( + f"'{name}' has no default configuration." + ) from ex + else: + if config_registry is None: + raise ValueError( + f"`base_config_name` must be a registered configuration name, but is '{base_config_name}' instead." + ) + + try: + base_config = config_registry.get(base_config_name) + except ValueError: + raise ValueError( + f"`base_config_name` must be a registered configuration name, but is '{base_config_name}' instead." + ) from None + + if config is None: + config = base_config + else: + config = merge_dataclass(base_config, config) + + f = partial(factory, config) + + f.config = config # type: ignore[attr-defined] + + return cast(ConfigBoundFactory[P, R], f) + + def register( + self, + name: str, + factory: Factory[ConfigT, P, R], + config_kls: type[ConfigT], + config_registry: ConfigRegistry[ConfigT] | None = None, + ) -> None: + """Register ``factory`` with ``name``.""" + if name in self._factories: + raise ValueError( + f"`name` must be a unique name, but '{name}' is already registered." + ) + + self._factories[name] = (factory, config_kls, config_registry) + + def decorator( + self, name: str + ) -> Callable[[Factory[ConfigT, P, R]], Factory[ConfigT, P, R]]: + """Register ``name`` with the decorated factory function.""" + + def register(factory: Factory[ConfigT, P, R]) -> Factory[ConfigT, P, R]: + if not isfunction(factory): + raise TypeError("`factory` must be a function.") + + type_hints = get_type_hints(factory) + + if len(type_hints) < 2: + raise ValueError( + f"The decorated factory `{factory}` must have at least one parameter." + ) + + config_kls = next(iter(type_hints.values())) + + if not is_dataclass(config_kls): + raise ValueError( + f"The first parameter of the decorated factory `{factory}` must be a dataclass." + ) + + self.register(name, factory, config_kls) # type: ignore[arg-type] + + return factory + + return register diff --git a/src/fairseq2/gang.py b/src/fairseq2/gang.py index 7df62344e..570e30ef7 100644 --- a/src/fairseq2/gang.py +++ b/src/fairseq2/gang.py @@ -8,22 +8,23 @@ import os from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass from datetime import timedelta from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Union, final +from typing import final import torch import torch.distributed as dist from torch import Tensor from torch.distributed import Backend, ProcessGroup, ReduceOp +from typing_extensions import override from fairseq2.device import determine_default_cuda_device, determine_default_device -from fairseq2.logging import get_log_writer -from fairseq2.typing import CPU, Device, override -from fairseq2.utils.env import get_int_from_env -from fairseq2.utils.version import torch_greater_or_equal - -log = get_log_writer(__name__) +from fairseq2.error import InternalError, InvalidOperationError, NotSupportedError +from fairseq2.logging import log +from fairseq2.typing import CPU, Device +from fairseq2.utils.env import InvalidEnvironmentVariableError, get_int_from_env class ReduceOperation(Enum): @@ -44,8 +45,8 @@ def close(self) -> None: """Close and destroy the gang.""" @abstractmethod - def create_gang(self, ranks: Sequence[int]) -> Optional[Gang]: - """Create a new gang. + def make_gang(self, ranks: Sequence[int]) -> Gang | None: + """Make a new gang. :param ranks: The ranks of processes that will be part of the new gang. @@ -81,7 +82,7 @@ def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None: @abstractmethod def all_gather_to_list( - self, output_tensors: List[Tensor], input_tensor: Tensor + self, output_tensors: list[Tensor], input_tensor: Tensor ) -> None: """Gather tensors from all processes and put them in ``output_tensors``. @@ -102,7 +103,7 @@ def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None: """ @abstractmethod - def broadcast_objects(self, objects: List[Any], source_rank: int = 0) -> None: + def broadcast_objects(self, objects: list[object], source_rank: int = 0) -> None: """Broadcast picklable ``objects`` from ``source_rank`` to all processes. :param objects: @@ -128,6 +129,10 @@ def device(self) -> Device: """The associated device.""" +class GangError(Exception): + pass + + class AbstractGang(Gang): """Provides a skeletal implementation of :class:`Gang`.""" @@ -159,7 +164,7 @@ def __init__(self, rank: int, size: int, device: Device) -> None: @final @override - def create_gang(self, ranks: Sequence[int]) -> Optional[Gang]: + def make_gang(self, ranks: Sequence[int]) -> Gang | None: if len(set(ranks)) != len(ranks): raise ValueError("The ranks in ``ranks`` must be all unique.") @@ -169,11 +174,11 @@ def create_gang(self, ranks: Sequence[int]) -> Optional[Gang]: f"The rank at index {idx} in ``ranks`` must be greater than or equal to 0 and less than the size of the gang ({self._size}), but is {rank} instead." ) - return self._do_create_gang(ranks) + return self._do_make_gang(ranks) @abstractmethod - def _do_create_gang(self, ranks: Sequence[int]) -> Optional[Gang]: - """Create a new gang. + def _do_make_gang(self, ranks: Sequence[int]) -> Gang | None: + """Make a new gang. :param ranks: The ranks of processes that will be part of the new gang. @@ -203,7 +208,7 @@ class FakeGang(AbstractGang): """Represents a non-distributed gang for local use.""" def __init__( - self, *, rank: int = 0, size: int = 1, device: Optional[Device] = None + self, *, rank: int = 0, size: int = 1, device: Device | None = None ) -> None: """ :param rank: @@ -224,7 +229,7 @@ def close(self) -> None: pass @override - def _do_create_gang(self, ranks: Sequence[int]) -> Optional[FakeGang]: + def _do_make_gang(self, ranks: Sequence[int]) -> FakeGang | None: try: idx = ranks.index(self._rank) except ValueError: @@ -234,7 +239,9 @@ def _do_create_gang(self, ranks: Sequence[int]) -> Optional[FakeGang]: @override def as_process_group(self) -> ProcessGroup: - raise RuntimeError("`FakeGang` does not support conversion to a process group.") + raise NotSupportedError( + "`FakeGang` does not support conversion to a process group." + ) @override def barrier(self) -> None: @@ -242,10 +249,15 @@ def barrier(self) -> None: @override def all_reduce(self, tensor: Tensor, op: ReduceOperation) -> None: - if op == ReduceOperation.SUM: - tensor *= self._size - elif op == ReduceOperation.PRODUCT: - tensor.pow_(self._size) + match op: + case ReduceOperation.SUM: + tensor *= self._size + case ReduceOperation.PRODUCT: + tensor.pow_(self._size) + case _: + raise NotSupportedError( + "`FakeGang` supports only `SUM` and `PRODUCT` reduce operations." + ) @override def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None: @@ -259,7 +271,7 @@ def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None: if output_tensor.size(0) != self._size: raise ValueError( - f"The size of the first dimension of `output_tensor` must match the size of the gang ({self._size}), but is {output_tensor.size(0)} instead." + f"The size of the first dimension of `output_tensor` must match the number of processes in the gang ({self._size}), but is {output_tensor.size(0)} instead." ) for i in range(self._size): @@ -267,11 +279,11 @@ def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None: @override def all_gather_to_list( - self, output_tensors: List[Tensor], input_tensor: Tensor + self, output_tensors: list[Tensor], input_tensor: Tensor ) -> None: if len(output_tensors) != self._size: raise ValueError( - f"The length of `output_tensors` must match the size of the gang ({self._size}), but is {len(output_tensors)} instead." + f"The length of `output_tensors` must match the number of processes in the gang ({self._size}), but is {len(output_tensors)} instead." ) for i in range(self._size): @@ -285,7 +297,7 @@ def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None: ) @override - def broadcast_objects(self, objects: List[Any], source_rank: int = 0) -> None: + def broadcast_objects(self, objects: list[object], source_rank: int = 0) -> None: if source_rank != self._rank: raise ValueError( f"`source_rank` must be {self._rank}, but is {source_rank} instead." @@ -296,17 +308,17 @@ def broadcast_objects(self, objects: List[Any], source_rank: int = 0) -> None: class ProcessGroupGang(AbstractGang): """Represents a gang that wraps a process group.""" - _default: Optional[ProcessGroupGang] = None + _default: ProcessGroupGang | None = None _pg: ProcessGroup - _monitor_pg: Optional[ProcessGroup] + _monitor_pg: ProcessGroup | None def __init__( self, pg: ProcessGroup, device: Device, *, - monitor_pg: Optional[ProcessGroup] = None, + monitor_pg: ProcessGroup | None = None, ) -> None: super().__init__(dist.get_rank(pg), dist.get_world_size(pg), device) @@ -317,9 +329,9 @@ def __init__( def init_default_process_group( cls, *, - device: Optional[Device] = None, - timeout: Optional[timedelta] = None, - num_threads: Optional[int] = None, + device: Device | None = None, + timeout: timedelta | None = None, + num_threads: int | None = None, monitored: bool = False, ok_initialized: bool = False, ) -> ProcessGroupGang: @@ -345,7 +357,7 @@ def init_default_process_group( dist.set_debug_level_from_env() if not dist.is_available(): - raise RuntimeError("`torch.distributed` is not available.") + raise GangError("`torch.distributed` is not available.") if dist.is_initialized(): if ok_initialized: @@ -353,9 +365,14 @@ def init_default_process_group( return ProcessGroupGang.from_default_process_group() - raise RuntimeError("The default process group is already initialized.") + raise GangError("The default process group is already initialized.") - num_procs = get_local_world_size() + try: + num_procs = get_local_world_size() + except InvalidEnvironmentVariableError as ex: + raise GangError( + "The local world size cannot be determined from the environment variables. See the nested exception for details." + ) from ex if num_threads is None: if num_procs > 1 and "OMP_NUM_THREADS" not in os.environ: @@ -371,9 +388,10 @@ def init_default_process_group( if device is None: device = determine_default_device() - assert device.type == "cpu" or device.type == "cuda" + if device.type != "cpu" and device.type != "cuda": + raise InternalError(f"`device` is `{device}`.") - backend: Optional[str] + backend: str | None if device.type == "cpu": backend = Backend.GLOO @@ -385,36 +403,34 @@ def init_default_process_group( ) if device.type == "cuda": - nccl_env_name = "NCCL_ASYNC_ERROR_HANDLING" - - if torch_greater_or_equal(2, 2): - try: - del os.environ[nccl_env_name] # Suppress the deprecation warning. - except KeyError: - pass - - nccl_env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING" - # See https://github.com/pytorch/pytorch/issues/46874. - os.environ[nccl_env_name] = "1" + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" if timeout is None: timeout = timedelta(minutes=15) - dist.init_process_group(backend, timeout=timeout) + try: + dist.init_process_group(backend, timeout=timeout) + except (RuntimeError, ValueError) as ex: + raise GangError( + "The underlying process group has failed to initialize. See the nested exception for details." + ) from ex pg = dist.group.WORLD if pg is None: - raise RuntimeError( - "The default process group is not available. Please file a bug report." - ) + raise InternalError("`dist.group.WORLD` is `None`.") if monitored: if backend == Backend.GLOO: monitor_pg = pg else: # Gloo is needed for monitored barrier support. - monitor_pg = dist.new_group(backend=Backend.GLOO, timeout=timeout) + try: + monitor_pg = dist.new_group(backend=Backend.GLOO, timeout=timeout) + except RuntimeError as ex: + raise GangError( + "The underlying process group used for monitoring has failed to initialize. See the nested exception for details." + ) from ex else: monitor_pg = None @@ -437,35 +453,39 @@ def from_process_group(pg: ProcessGroup, device: Device) -> ProcessGroupGang: def from_default_process_group(cls) -> ProcessGroupGang: """Wrap the default process group as a gang.""" if not dist.is_available(): - raise RuntimeError("`torch.distributed` is not available.") + raise GangError("`torch.distributed` is not available.") if not dist.is_initialized(): - raise RuntimeError("The default process group is not initialized.") + raise GangError("The default process group is not initialized.") if cls._default is not None: return cls._default - backend = dist.get_backend() - - if backend == Backend.GLOO: - device = CPU - elif backend == Backend.NCCL: - cuda_device = determine_default_cuda_device() - if cuda_device is None: - raise RuntimeError( - "The default process group uses the `nccl` backend, but the `cuda` device cannot be determined. Please file a bug report." + try: + backend = dist.get_backend() + except RuntimeError as ex: + raise GangError( + "The default process group backend cannot be determined. See the nested exception for details." + ) from ex + + match backend: + case Backend.GLOO: + device = CPU + case Backend.NCCL: + cuda_device = determine_default_cuda_device() + if cuda_device is None: + raise GangError( + "The default process group uses the `nccl` backend, but the `cuda` device cannot be determined." + ) + + device = cuda_device + case _: + raise NotSupportedError( + f"Only `nccl` and `gloo` backends are supported, but the process group uses the `{backend}` backend." ) - device = cuda_device - else: - raise RuntimeError( - f"Only `nccl` and `gloo` backends are supported, but the process group uses the `{backend}` backend." - ) - if dist.group.WORLD is None: - raise RuntimeError( - "The default process group is not available. Please file a bug report." - ) + raise InternalError("`dist.group.WORLD` is `None`.") cls._default = ProcessGroupGang(dist.group.WORLD, device) @@ -476,15 +496,27 @@ def close(self) -> None: dist.destroy_process_group(self._pg) @override - def _do_create_gang(self, ranks: Sequence[int]) -> Optional[ProcessGroupGang]: + def _do_make_gang(self, ranks: Sequence[int]) -> ProcessGroupGang | None: if self._pg is not dist.group.WORLD: - raise RuntimeError( - "`create_gang()` can only be called on the gang associated with the default (i.e. main) process group." + raise InvalidOperationError( + "`make_gang()` can only be called on the gang associated with the default (i.e. main) process group." ) - backend = dist.get_backend() + try: + backend = dist.get_backend() + except RuntimeError as ex: + raise GangError( + "The default process group backend cannot be determined. See the nested exception for details." + ) from ex + + try: + pg = dist.new_group(ranks, backend=backend) + except RuntimeError as ex: + s = ", ".join(sorted(str(r) for r in ranks)) - pg = dist.new_group(ranks, backend=backend) + raise GangError( + f"The creation of a new child process group has failed for ranks {s}. See the nested exception for details." + ) from ex if self._rank not in ranks: return None @@ -493,7 +525,14 @@ def _do_create_gang(self, ranks: Sequence[int]) -> Optional[ProcessGroupGang]: if backend == Backend.GLOO: monitor_pg = pg else: - monitor_pg = dist.new_group(ranks, backend=Backend.GLOO) + try: + monitor_pg = dist.new_group(ranks, backend=Backend.GLOO) + except RuntimeError as ex: + s = ", ".join(sorted(str(r) for r in ranks)) + + raise GangError( + f"The creation of a new monitoring child process group has failed for ranks {s}. See the nested exception for details." + ) from ex else: monitor_pg = None @@ -506,43 +545,78 @@ def as_process_group(self) -> ProcessGroup: @override def barrier(self) -> None: if self._monitor_pg is None: - dist.barrier(group=self._pg, device_ids=[self._device.index]) + try: + dist.barrier(group=self._pg, device_ids=[self._device.index]) + except RuntimeError as ex: + raise GangError( + "The `barrier` collective operation has failed. See the nested exception for details." + ) from ex else: torch.cuda.synchronize() - dist.monitored_barrier(group=self._monitor_pg, wait_all_ranks=True) + try: + dist.monitored_barrier(group=self._monitor_pg, wait_all_ranks=True) + except RuntimeError as ex: + raise GangError( + "The `monitored_barrier` collective operation has failed. See the nested exception for details." + ) from ex @override def all_reduce(self, tensor: Tensor, op: ReduceOperation) -> None: self._maybe_monitored_barrier() - dist.all_reduce(tensor, self._get_reduce_op(op), group=self._pg) + try: + dist.all_reduce(tensor, self._get_reduce_op(op), group=self._pg) + except RuntimeError as ex: + raise GangError( + "The `all_reduce` collective operation has failed. See the nested exception for details." + ) from ex @override def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None: self._maybe_monitored_barrier() - dist.all_gather_into_tensor(output_tensor, input_tensor, group=self._pg) + try: + dist.all_gather_into_tensor(output_tensor, input_tensor, group=self._pg) + except RuntimeError as ex: + raise GangError( + "The `all_gather` collective operation has failed. See the nested exception for details." + ) from ex @override def all_gather_to_list( - self, output_tensors: List[Tensor], input_tensor: Tensor + self, output_tensors: list[Tensor], input_tensor: Tensor ) -> None: self._maybe_monitored_barrier() - dist.all_gather(output_tensors, input_tensor, group=self._pg) + try: + dist.all_gather(output_tensors, input_tensor, group=self._pg) + except RuntimeError as ex: + raise GangError( + "The `all_gather_to_list` collective operation has failed. See the nested exception for details." + ) from ex @override def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None: self._maybe_monitored_barrier() - dist.broadcast(tensor, source_rank, group=self._pg) + try: + dist.broadcast(tensor, source_rank, group=self._pg) + except RuntimeError as ex: + raise GangError( + "The `broadcast` collective operation has failed. See the nested exception for details." + ) from ex @override - def broadcast_objects(self, objects: List[Any], source_rank: int = 0) -> None: + def broadcast_objects(self, objects: list[object], source_rank: int = 0) -> None: self._maybe_monitored_barrier() - dist.broadcast_object_list(objects, source_rank, group=self._pg) + try: + dist.broadcast_object_list(objects, source_rank, group=self._pg) + except RuntimeError as ex: + raise GangError( + "The `broadcast_object_list` collective operation has failed. See the nested exception for details." + ) from ex def _maybe_monitored_barrier(self) -> None: if self._monitor_pg is None: @@ -550,7 +624,12 @@ def _maybe_monitored_barrier(self) -> None: torch.cuda.synchronize() - dist.monitored_barrier(group=self._monitor_pg, wait_all_ranks=True) + try: + dist.monitored_barrier(group=self._monitor_pg, wait_all_ranks=True) + except RuntimeError as ex: + raise GangError( + "The `monitored_barrier` collective operation has failed. See the nested exception for details." + ) from ex @staticmethod def _get_reduce_op(op: ReduceOperation): # type: ignore[no-untyped-def] @@ -565,8 +644,8 @@ def _get_reduce_op(op: ReduceOperation): # type: ignore[no-untyped-def] if op == ReduceOperation.MAX: return ReduceOp.MAX - raise ValueError( - f"`op` must be an operation supported by the underlying process group, but is `{op}` instead." + raise NotSupportedError( + f"`{op}` operation is not supported by the underlying process group." ) @@ -584,41 +663,13 @@ def _get_num_cpus(num_procs: int) -> int: return min(max(num_cpus // num_procs, 1), len(affinity_mask)) -def get_world_size() -> int: - """Return the world size of the running job.""" - value = get_int_from_env("WORLD_SIZE") - - return 1 if value is None else value - - -def get_rank() -> int: - """Return the rank of this process in the running job.""" - value = get_int_from_env("RANK", allow_zero=True) - - return 0 if value is None else value - - -def get_local_world_size() -> int: - """Return the local world size of the running job.""" - value = get_int_from_env("LOCAL_WORLD_SIZE") - - return 1 if value is None else value - - -def get_local_rank() -> int: - """Return the local rank of this process in the running job.""" - value = get_int_from_env("LOCAL_RANK", allow_zero=True) - - return 0 if value is None else value - - def setup_default_gang( *, - device: Optional[Device] = None, - timeout: Optional[timedelta] = None, + device: Device | None = None, + timeout: timedelta | None = None, monitored: bool = False, ) -> Gang: - """Set up the default gang of this process. + """Make the default gang of this process. :param device: If ``None``; if CUDA is available, the gang will use the default CUDA @@ -628,7 +679,14 @@ def setup_default_gang( :param monitored: If ``True``, puts a monitored barrier before every collective call. """ - if get_world_size() == 1: + try: + world_size = get_world_size() + except InvalidEnvironmentVariableError as ex: + raise GangError( + "The world size cannot be determined from the environment variables. See the nested exception for details." + ) from ex + + if world_size == 1: return FakeGang(device=device) return ProcessGroupGang.init_default_process_group( @@ -636,11 +694,172 @@ def setup_default_gang( ) -def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Dict[str, Gang]: - """Set up gangs to be used for data and tensor parallelism. +@dataclass +class Gangs: + root: Gang + dp: Gang + tp: Gang + + +def fake_gangs(device: Device) -> Gangs: + gang = FakeGang(device=device) + + return Gangs(gang, gang, gang) + + +def _setup_2D_mesh_gangs( + root_gang: Gang, + *, + row_length: int = 1, + create_single_rank_process_groups: bool = False, + dim_descriptions: list[str] | None = None, +) -> dict[int, Gang]: + """Set up gangs for this process as defined by a 2D device mesh. + + The two returned gangs are defined by the process' position in the mesh. + First gang is the row in the mesh, second is the column. + For example, assuming 8 devices denoted by g0 to g7, calling this function + with ``row_length`` = 4 amounts to defining the 2D mesh + [[g0, g1, g2, g3], [g4, g5, g6, g7]] and making 2 sets of gangs: + + 2 gangs of size 4 (mesh rows): + [g0, g1, g2, g3], [g4, g5, g6, g7] + 4 gangs of size 2 (mesh columns): + [g0, g4], [g1, g5], [g2, g6], [g3, g7] + + For the process of rank 5, the function would return the 2 sub-gangs + {0: [g4, g5, g6, g7], 1: [g1, g5]}. If adjacent ranks are on the same host + (for example, 2 hosts: one with g0 to g3, and the other with g4 to g7), + the first gang can be used to maximize local intra-host communication. + + Example use-cases include making tensor- and data- parallel gangs, or + sharding and replicating gangs in FSDP's hybrid sharding. + + :param root_gang: + The gang whose topology will be used to make the new gangs. + :param row_length: + The size of the gangs corresponding to the 2D mesh rows. + :param create_single_rank_process_groups: + If ``True``, create an underlying ``dist.ProcessGroup`` even for single-rank gangs. + The gang is faked otherwise. + :param dim_descriptions: + String descriptions of returned gangs, used in log and error messages. + + :returns: + A ``dict`` of two gangs; 0 maps to the gang of 2D mesh row, + 1 maps to the gang of the 2D mesh column. + """ + row_count = root_gang.size // row_length + + mesh = torch.arange(root_gang.size).view(row_count, row_length) + + # Get the coordinate of this process in the mesh. + rank_coords = [x.item() for x in torch.where(mesh == root_gang.rank)] + mesh_shape = mesh.size() + + output = {} + + log.info( + "Initializing sub-gangs for a 2D device mesh of shape {}.", list(mesh_shape) + ) + if dim_descriptions is None: + dim_descriptions = [f"dim-{dim}" for dim in range(2)] + + for dim in range(2): + current_subgang: Gang | None = None + + gang_size = mesh_shape[1 - dim] + + log.info( + "Initializing {} gang with a size of {}.", dim_descriptions[dim], gang_size + ) + + # Match row length (dim 0) or column length (dim 1) + match gang_size: + case 1: + if create_single_rank_process_groups: + current_subgang = root_gang.make_gang([root_gang.rank]) + else: + current_subgang = FakeGang(device=root_gang.device) + case root_gang.size: + current_subgang = root_gang + case _: + # Create 1 gang per row (dim 0) or per column (dim 1) + for i in range(mesh_shape[dim]): + ranks = mesh[i, :] if dim == 0 else mesh[:, i] + sub_gang = root_gang.make_gang(ranks.tolist()) + if i == rank_coords[dim]: + current_subgang = sub_gang + + if current_subgang is None: + raise InternalError(f"`current_gang` ({dim_descriptions[dim]}) is `None`.") + + output[dim] = current_subgang + + return output + + +def setup_hybrid_fsdp_gangs(gang: Gang, local_world_size: int) -> tuple[Gang, Gang]: + """Make gangs to be used for hybrid-sharding FSDP. + + For instance; if we have 8 devices denoted by g0 to g7 and ``local_world_size`` + is 4, this function will make 2 sharding gangs and 4 replication gangs: + + 2 sharding gangs of size 4: + [g0, g1, g2, g3], [g4, g5, g6, g7] + 4 replication gangs of size 2: + [g0, g4], [g1, g5], [g2, g6], [g3, g7] + + For efficiency, the caller should make sure adjacent ranks are on the same + host. + + :param gang: + The gang over which to shard and replicate. + :param local_world_size: + ``gang`` will be split into sub-gangs each containing + ``local_world_size`` number of consecutive processes. + The model will be fully sharded within each sub-gang and + will be replicated across sub-gangs. + + :returns: + A pair of two gangs: the sharding gang that the current process is + part of, and the replication gang that the current process is part of + """ + if local_world_size < 1: + raise ValueError( + f"`local_world_size` must be greater than 1, but is {local_world_size} instead." + ) + + if local_world_size == 1: + raise GangError( + f"`local_world_size` must be greater than 1, but is {local_world_size} instead. This hybrid configuration would force FSDP to switch to use `NO_SHARD`, which is deprecated. Please use DDP instead." + ) + + if local_world_size > gang.size: + raise ValueError( + f"`local_world_size` must be less than or equal to `gang.size` ({gang.size}), but is {local_world_size} instead." + ) + + if gang.size % local_world_size != 0: + raise GangError( + f"`gang.size` ({gang.size}) must be a multiple of `local_world_size` ({local_world_size})." + ) + + sub_gangs = _setup_2D_mesh_gangs( + gang, + row_length=local_world_size, + create_single_rank_process_groups=True, + dim_descriptions=["sharding", "replication"], + ) + + return sub_gangs[0], sub_gangs[1] + + +def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Gangs: + """Make gangs to be used for data and tensor parallelism. For instance; if we have 8 devices denoted by g0 to g7 and 2 devices are - used for tensor parallelism, this function will create 4 tensor parallel + used for tensor parallelism, this function will make 4 tensor parallel gangs and 2 data parallel gangs as: 4 tensor parallel gangs: @@ -653,63 +872,30 @@ def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Dict[str, Gang to 7 belong to the first host and ranks 8 to 15 belong to the second host. :param root_gang: - The gang whose topology will be used to create the new gangs. + The gang whose topology will be used to make the new gangs. :param tp_size: The size of tensor parallel gangs. :returns: - A ``dict`` of two gangs; (1) the data parallel gang that this process - is part of denoted by the key "dp", (2) the tensor parallel gang that - this process is part of denoted by the key "tp". + Three gangs: the root gang, the data parallel gang that this + process is part of, and the tensor parallel gang that this process is + part of. """ if tp_size <= 0: raise ValueError(f"`tp_size` must be greater than 0, but is {tp_size} instead.") if root_gang.size % tp_size != 0: - raise ValueError( - f"`root_gang.size` ({root_gang.size}) must be divisible by `tp_size` ({tp_size})." + raise GangError( + f"The number of processes in the root gang is expected to be a multiple of the tensor parallel size ({tp_size}), but is {root_gang.size} instead." ) - dp_size = root_gang.size // tp_size - - if log.is_enabled_for_info(): - for name, size in [("data", dp_size), ("tensor", tp_size)]: - log.info("Initializing {} parallelism with a gang of size {}.", name, size) - - mesh = torch.arange(root_gang.size).view(dp_size, tp_size) - - # Get the coordinate of this process in the mesh. - rank_coords = [x.item() for x in torch.where(mesh == root_gang.rank)] - - dp_gang: Optional[Gang] = None - tp_gang: Optional[Gang] = None - - # Build the gangs for data parallelism. - if dp_size == 1: - dp_gang = FakeGang(device=root_gang.device) - elif dp_size == root_gang.size: - dp_gang = root_gang - else: - for i in range(tp_size): - sub_gang = root_gang.create_gang(mesh[:, i].tolist()) - if i == rank_coords[1]: - dp_gang = sub_gang - - # Build the gangs for tensor parallelism. - if tp_size == 1: - tp_gang = FakeGang(device=root_gang.device) - elif tp_size == root_gang.size: - tp_gang = root_gang - else: - for i in range(dp_size): - sub_gang = root_gang.create_gang(mesh[i, :].tolist()) - if i == rank_coords[0]: - tp_gang = sub_gang - - assert dp_gang is not None - assert tp_gang is not None + output_from_2D_mesh = _setup_2D_mesh_gangs( + root_gang, + row_length=tp_size, + dim_descriptions=["tensor parallel", "data parallel"], + ) - return {"root": root_gang, "dp": dp_gang, "tp": tp_gang} + return Gangs(root_gang, output_from_2D_mesh[1], output_from_2D_mesh[0]) def broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) -> bool: @@ -721,7 +907,7 @@ def broadcast_flag(gang: Gang, flag: bool, source_rank: int = 0) -> bool: return bool(tmp) -def all_sum(gang: Gang, value: Union[float, int, Tensor]) -> Tensor: +def all_sum(gang: Gang, value: float | int | Tensor) -> Tensor: """Sum ``value`` over all processes in ``gang``.""" if isinstance(value, Tensor): output = value @@ -731,3 +917,36 @@ def all_sum(gang: Gang, value: Union[float, int, Tensor]) -> Tensor: gang.all_reduce(output, ReduceOperation.SUM) return output + + +def get_world_size() -> int: + """Return the world size of the running job.""" + value = get_int_from_env("WORLD_SIZE") + + return 1 if value is None else value + + +def get_rank() -> int: + """Return the rank of this process in the running job.""" + value = get_int_from_env("RANK", allow_zero=True) + + return 0 if value is None else value + + +def get_local_world_size() -> int: + """Return the local world size of the running job.""" + value = get_int_from_env("LOCAL_WORLD_SIZE") + + return 1 if value is None else value + + +def get_local_rank() -> int: + """Return the local rank of this process in the running job.""" + value = get_int_from_env("LOCAL_RANK", allow_zero=True) + + return 0 if value is None else value + + +def is_torchrun() -> bool: + """Return ``True`` if this process was spawned by torchrun.""" + return "TORCHELASTIC_RUN_ID" in os.environ diff --git a/src/fairseq2/generation/__init__.py b/src/fairseq2/generation/__init__.py index 5f098bb44..9285b53d9 100644 --- a/src/fairseq2/generation/__init__.py +++ b/src/fairseq2/generation/__init__.py @@ -6,20 +6,45 @@ from __future__ import annotations -from fairseq2.generation.beam_search import BeamSearchAlgorithm as BeamSearchAlgorithm -from fairseq2.generation.beam_search import ( +from fairseq2.generation.beam_search.algo import ( + STANDARD_BEAM_SEARCH_ALGO as STANDARD_BEAM_SEARCH_ALGO, +) +from fairseq2.generation.beam_search.algo import ( + BeamSearchAlgorithm as BeamSearchAlgorithm, +) +from fairseq2.generation.beam_search.algo import ( + BeamSearchAlgorithmHandler as BeamSearchAlgorithmHandler, +) +from fairseq2.generation.beam_search.algo import ( + BeamSearchAlgorithmNotFoundError as BeamSearchAlgorithmNotFoundError, +) +from fairseq2.generation.beam_search.algo import BeamStep as BeamStep +from fairseq2.generation.beam_search.algo import ( + StandardBeamSearchAlgorithm as StandardBeamSearchAlgorithm, +) +from fairseq2.generation.beam_search.algo import ( + StandardBeamSearchAlgorithmHandler as StandardBeamSearchAlgorithmHandler, +) +from fairseq2.generation.beam_search.generator import ( BeamSearchSeq2SeqGenerator as BeamSearchSeq2SeqGenerator, ) -from fairseq2.generation.beam_search import ( +from fairseq2.generation.beam_search.generator import ( BeamSearchSequenceGenerator as BeamSearchSequenceGenerator, ) -from fairseq2.generation.beam_search import ( - StandardBeamSearchAlgorithm as StandardBeamSearchAlgorithm, +from fairseq2.generation.beam_search.handler import ( + BEAM_SEARCH_GENERATOR as BEAM_SEARCH_GENERATOR, +) +from fairseq2.generation.beam_search.handler import AlgorithmSection as AlgorithmSection +from fairseq2.generation.beam_search.handler import ( + AlgorithmSectionHandler as AlgorithmSectionHandler, +) +from fairseq2.generation.beam_search.handler import BeamSearchConfig as BeamSearchConfig +from fairseq2.generation.beam_search.handler import ( + BeamSearchSeq2SeqGeneratorHandler as BeamSearchSeq2SeqGeneratorHandler, +) +from fairseq2.generation.beam_search.handler import ( + BeamSearchSequenceGeneratorHandler as BeamSearchSequenceGeneratorHandler, ) -from fairseq2.generation.chatbot import AbstractChatbot as AbstractChatbot -from fairseq2.generation.chatbot import Chatbot as Chatbot -from fairseq2.generation.chatbot import ChatDialog as ChatDialog -from fairseq2.generation.chatbot import ChatMessage as ChatMessage from fairseq2.generation.generator import ( AbstractSeq2SeqGenerator as AbstractSeq2SeqGenerator, ) @@ -36,15 +61,59 @@ SequenceGeneratorOutput as SequenceGeneratorOutput, ) from fairseq2.generation.generator import StepHook as StepHook -from fairseq2.generation.sampling import Sampler as Sampler -from fairseq2.generation.sampling import ( +from fairseq2.generation.handler import ( + Seq2SeqGeneratorHandler as Seq2SeqGeneratorHandler, +) +from fairseq2.generation.handler import ( + Seq2SeqGeneratorNotFoundError as Seq2SeqGeneratorNotFoundError, +) +from fairseq2.generation.handler import ( + SequenceGeneratorHandler as SequenceGeneratorHandler, +) +from fairseq2.generation.handler import ( + SequenceGeneratorNotFoundError as SequenceGeneratorNotFoundError, +) +from fairseq2.generation.sampling.generator import ( SamplingSeq2SeqGenerator as SamplingSeq2SeqGenerator, ) -from fairseq2.generation.sampling import ( +from fairseq2.generation.sampling.generator import ( SamplingSequenceGenerator as SamplingSequenceGenerator, ) -from fairseq2.generation.sampling import TopKSampler as TopKSampler -from fairseq2.generation.sampling import TopPSampler as TopPSampler +from fairseq2.generation.sampling.handler import ( + SAMPLING_GENERATOR as SAMPLING_GENERATOR, +) +from fairseq2.generation.sampling.handler import SamplerSection as SamplerSection +from fairseq2.generation.sampling.handler import ( + SamplerSectionHandler as SamplerSectionHandler, +) +from fairseq2.generation.sampling.handler import SamplingConfig as SamplingConfig +from fairseq2.generation.sampling.handler import ( + SamplingSeq2SeqGeneratorHandler as SamplingSeq2SeqGeneratorHandler, +) +from fairseq2.generation.sampling.handler import ( + SamplingSequenceGeneratorHandler as SamplingSequenceGeneratorHandler, +) +from fairseq2.generation.sampling.sampler import TOP_K_SAMPLER as TOP_K_SAMPLER +from fairseq2.generation.sampling.sampler import TOP_P_SAMPLER as TOP_P_SAMPLER +from fairseq2.generation.sampling.sampler import Sampler as Sampler +from fairseq2.generation.sampling.sampler import SamplerHandler as SamplerHandler +from fairseq2.generation.sampling.sampler import ( + SamplerNotFoundError as SamplerNotFoundError, +) +from fairseq2.generation.sampling.sampler import TopKSampler as TopKSampler +from fairseq2.generation.sampling.sampler import TopKSamplerConfig as TopKSamplerConfig +from fairseq2.generation.sampling.sampler import ( + TopKSamplerHandler as TopKSamplerHandler, +) +from fairseq2.generation.sampling.sampler import TopPSampler as TopPSampler +from fairseq2.generation.sampling.sampler import TopPSamplerConfig as TopPSamplerConfig +from fairseq2.generation.sampling.sampler import ( + TopPSamplerHandler as TopPSamplerHandler, +) +from fairseq2.generation.static import ( + create_seq2seq_generator as create_seq2seq_generator, +) +from fairseq2.generation.static import create_seq_generator as create_seq_generator from fairseq2.generation.step_processor import ( BannedSequenceProcessor as BannedSequenceProcessor, ) diff --git a/src/fairseq2/generation/beam_search/__init__.py b/src/fairseq2/generation/beam_search/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/fairseq2/generation/beam_search/algo.py b/src/fairseq2/generation/beam_search/algo.py new file mode 100644 index 000000000..4ea31674c --- /dev/null +++ b/src/fairseq2/generation/beam_search/algo.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from types import NoneType +from typing import Final, final + +import torch +from torch import Tensor +from typing_extensions import override + + +class BeamSearchAlgorithm(ABC): + """Represents a beam search algorithm.""" + + @abstractmethod + def step(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep: + """Take a single step. + + A subclass implementation is expected to return the best 2 x `beam_size` + candidates. The sequence generator will choose the first `beam_size` of + these which don't predict EOS to continue with. + + :param beam_size: + The beam size. + :param lprobs: + The next-step log probability of each vocabulary entry. *Shape:* + :math:`(N,V)`, where :math:`N` is the batch size and :math:`V` is + the size of the vocabulary. + :param step_scores: + The cumulative score of each step in the beam. *Shape:* :math:`(N,S)`, + where :math:`N` is the batch size and :math:`S` is the length of the + beam. + """ + + +@final +class StandardBeamSearchAlgorithm(BeamSearchAlgorithm): + """Represents a standard beam search algoritm.""" + + @override + def step(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep: + vocab_size = lprobs.size(1) + + # Make the probabilities contain cumulative scores for each hypothesis. + # (N, V) + (N, 1) = (N, V) + lprobs = lprobs + step_scores[:, -1].unsqueeze(-1) + + # (N, V) -> (N x V) + lprobs = lprobs.view(-1) + + # (2 x B) + top_scores, top_indices = torch.topk(lprobs, k=min(2 * beam_size, vocab_size)) + + return BeamStep(top_indices // vocab_size, top_indices % vocab_size, top_scores) + + +@final +@dataclass +class BeamStep: + """Represents the output of a beam search algorithm.""" + + seq_indices: Tensor + """The beam sequence indices. *Shape:* :math:`(B)`, where :math:`B` is the + beam size.""" + + vocab_indices: Tensor + """The vocabulary indices. *Shape:* Same as ``seq_indices``.""" + + scores: Tensor + """The scores. *Shape:* Same as ``seq_indices``.""" + + def masked_select(self, mask: Tensor) -> BeamStep: + """Reduce the beam to the sequences included in ``mask``.""" + seq_indices = self.seq_indices.masked_select(mask) + + vocab_indices = self.vocab_indices.masked_select(mask) + + scores = self.scores.masked_select(mask) + + return BeamStep(seq_indices, vocab_indices, scores) + + def first(self, count: int) -> BeamStep: + """Slice the beam to the first ``count`` sequences.""" + seq_indices = self.seq_indices[:count] + + vocab_indices = self.vocab_indices[:count] + + scores = self.scores[:count] + + return BeamStep(seq_indices, vocab_indices, scores) + + @staticmethod + def merge(steps: Sequence[BeamStep]) -> BeamStep: + """Merge ``steps`` into a single beam.""" + seq_indices = torch.cat([s.seq_indices for s in steps]) + + vocab_indices = torch.cat([s.vocab_indices for s in steps]) + + scores = torch.cat([s.scores for s in steps]) + + return BeamStep(seq_indices, vocab_indices, scores) + + +class BeamSearchAlgorithmHandler(ABC): + @abstractmethod + def create(self, config: object) -> BeamSearchAlgorithm: + ... + + @property + @abstractmethod + def config_kls(self) -> type: + ... + + +class BeamSearchAlgorithmNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known beam search algorithm.") + + self.name = name + + +STANDARD_BEAM_SEARCH_ALGO: Final = "standard" + + +@final +class StandardBeamSearchAlgorithmHandler(BeamSearchAlgorithmHandler): + @override + def create(self, config: object) -> BeamSearchAlgorithm: + if config is not None: + raise ValueError( + "`config` must not be specified for standard beam-search algorithm." + ) + + return StandardBeamSearchAlgorithm() + + @property + @override + def config_kls(self) -> type: + return NoneType diff --git a/src/fairseq2/generation/beam_search.py b/src/fairseq2/generation/beam_search/generator.py similarity index 84% rename from src/fairseq2/generation/beam_search.py rename to src/fairseq2/generation/beam_search/generator.py index 3c017bb32..5c2a32732 100644 --- a/src/fairseq2/generation/beam_search.py +++ b/src/fairseq2/generation/beam_search/generator.py @@ -7,20 +7,28 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple, Union, final +from collections.abc import Sequence +from typing import final import torch from torch import Tensor from torch.nn.functional import log_softmax +from typing_extensions import override from fairseq2.data import VocabularyInfo +from fairseq2.error import InternalError +from fairseq2.generation.beam_search.algo import ( + BeamSearchAlgorithm, + BeamStep, + StandardBeamSearchAlgorithm, +) from fairseq2.generation.generator import ( AbstractSeq2SeqGenerator, AbstractSequenceGenerator, GenerationCounters, Hypothesis, Seq2SeqGeneratorOutput, + SequenceGenerationError, SequenceGeneratorOutput, StepHook, ) @@ -30,7 +38,6 @@ from fairseq2.models.sequence import SequenceModelOutput from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import override from fairseq2.utils.profiler import Stopwatch @@ -48,27 +55,27 @@ class BeamSearchSequenceGenerator(AbstractSequenceGenerator): _temperature: float _unk_penalty: float _len_penalty: float - _prefill_chunk_size: Optional[int] - _decode_capacity_increment: Optional[int] + _prefill_chunk_size: int | None + _decode_capacity_increment: int | None _step_processors: Sequence[StepProcessor] def __init__( self, model: DecoderModel, *, - algorithm: Optional[BeamSearchAlgorithm] = None, + algorithm: BeamSearchAlgorithm | None = None, beam_size: int = 5, min_gen_len: int = 1, max_gen_len: int = 128, - max_seq_len: Optional[int] = None, + max_seq_len: int | None = None, echo_prompt: bool = False, normalize_scores: bool = True, temperature: float = 1.0, unk_penalty: float = 0.0, len_penalty: float = 1.0, - prefill_chunk_size: Optional[int] = 512, - decode_capacity_increment: Optional[int] = 16, - step_processors: Optional[Sequence[StepProcessor]] = None, + prefill_chunk_size: int | None = 512, + decode_capacity_increment: int | None = 16, + step_processors: Sequence[StepProcessor] | None = None, ) -> None: """ :param model: @@ -159,7 +166,7 @@ def __init__( @torch.inference_mode() @override def __call__( - self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask] + self, prompt_seqs: Tensor, prompt_padding_mask: PaddingMask | None ) -> SequenceGeneratorOutput: op = _BeamSearchSequenceGeneratorOp( self._model, @@ -193,34 +200,34 @@ class BeamSearchSeq2SeqGenerator(AbstractSeq2SeqGenerator): _algorithm: BeamSearchAlgorithm _beam_size: int _min_gen_len: int - _max_gen_len: Tuple[int, int] + _max_gen_len: tuple[int, int] _max_seq_len: int _echo_prompt: bool _normalize_scores: bool _temperature: float _unk_penalty: float _len_penalty: float - _prefill_chunk_size: Optional[int] - _decode_capacity_increment: Optional[int] + _prefill_chunk_size: int | None + _decode_capacity_increment: int | None _step_processors: Sequence[StepProcessor] def __init__( self, model: EncoderDecoderModel, *, - algorithm: Optional[BeamSearchAlgorithm] = None, + algorithm: BeamSearchAlgorithm | None = None, beam_size: int = 5, min_gen_len: int = 1, - max_gen_len: Tuple[int, int] = (1, 128), - max_seq_len: Optional[int] = None, + max_gen_len: tuple[int, int] = (1, 128), + max_seq_len: int | None = None, echo_prompt: bool = False, normalize_scores: bool = True, temperature: float = 1.0, unk_penalty: float = 0.0, len_penalty: float = 1.0, - prefill_chunk_size: Optional[int] = 512, - decode_capacity_increment: Optional[int] = 16, - step_processors: Optional[Sequence[StepProcessor]] = None, + prefill_chunk_size: int | None = 512, + decode_capacity_increment: int | None = 16, + step_processors: Sequence[StepProcessor] | None = None, ) -> None: """ :param model: @@ -304,9 +311,9 @@ def __init__( def __call__( self, source_seqs: Tensor, - source_padding_mask: Optional[PaddingMask], + source_padding_mask: PaddingMask | None, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, ) -> Seq2SeqGeneratorOutput: # (P, S) encoder_output, encoder_padding_mask = self.model.encode( @@ -363,103 +370,11 @@ def __call__( ) -class BeamSearchAlgorithm(ABC): - """Represents a beam search algorithm.""" - - @abstractmethod - def __call__(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep: - """Take a single step. - - A subclass implementation is expected to return the best 2 x `beam_size` - candidates. The sequence generator will choose the first `beam_size` of - these which don't predict EOS to continue with. - - :param beam_size: - The beam size. - :param lprobs: - The next-step log probability of each vocabulary entry. *Shape:* - :math:`(N,V)`, where :math:`N` is the batch size and :math:`V` is - the size of the vocabulary. - :param step_scores: - The cumulative score of each step in the beam. *Shape:* :math:`(N,S)`, - where :math:`N` is the batch size and :math:`S` is the length of the - beam. - """ - - -@final -@dataclass -class BeamStep: - """Represents the output of a beam search algorithm.""" - - seq_indices: Tensor - """The beam sequence indices. *Shape:* :math:`(B)`, where :math:`B` is the - beam size.""" - - vocab_indices: Tensor - """The vocabulary indices. *Shape:* Same as ``seq_indices``.""" - - scores: Tensor - """The scores. *Shape:* Same as ``seq_indices``.""" - - def masked_select(self, mask: Tensor) -> BeamStep: - """Reduce the beam to the sequences included in ``mask``.""" - seq_indices = self.seq_indices.masked_select(mask) - - vocab_indices = self.vocab_indices.masked_select(mask) - - scores = self.scores.masked_select(mask) - - return BeamStep(seq_indices, vocab_indices, scores) - - def first(self, count: int) -> BeamStep: - """Slice the beam to the first ``count`` sequences.""" - seq_indices = self.seq_indices[:count] - - vocab_indices = self.vocab_indices[:count] - - scores = self.scores[:count] - - return BeamStep(seq_indices, vocab_indices, scores) - - @staticmethod - def merge(steps: Sequence[BeamStep]) -> BeamStep: - """Merge ``steps`` into a single beam.""" - seq_indices = torch.cat([s.seq_indices for s in steps]) - - vocab_indices = torch.cat([s.vocab_indices for s in steps]) - - scores = torch.cat([s.scores for s in steps]) - - return BeamStep(seq_indices, vocab_indices, scores) - - -@final -class StandardBeamSearchAlgorithm(BeamSearchAlgorithm): - """Represents a standard beam search algoritm.""" - - @override - def __call__(self, beam_size: int, lprobs: Tensor, step_scores: Tensor) -> BeamStep: - vocab_size = lprobs.size(1) - - # Make the probabilities contain cumulative scores for each hypothesis. - # (N, V) + (N, 1) = (N, V) - lprobs = lprobs + step_scores[:, -1].unsqueeze(-1) - - # (N, V) -> (N x V) - lprobs = lprobs.view(-1) - - # (2 x B) - top_scores, top_indices = torch.topk(lprobs, k=min(2 * beam_size, vocab_size)) - - return BeamStep(top_indices // vocab_size, top_indices % vocab_size, top_scores) - - class _AbstractBeamSearchSequenceGeneratorOp(ABC): _algorithm: BeamSearchAlgorithm _eos_idx: int - _pad_idx: Optional[int] - _unk_idx: Optional[int] + _pad_idx: int | None + _unk_idx: int | None _beam_size: int _min_prompt_len: int _max_prompt_len: int @@ -470,24 +385,24 @@ class _AbstractBeamSearchSequenceGeneratorOp(ABC): _temperature: float _unk_penalty: float _len_penalty: float - _prefill_chunk_size: Optional[int] + _prefill_chunk_size: int | None _step_processors: Sequence[StepProcessor] - _step_hooks: Dict[int, StepHook] + _step_hooks: dict[int, StepHook] _step_nr: int _state_bag: IncrementalStateBag - _prompt_lens: Optional[Tensor] - _prompt_mask: Optional[Tensor] - _beam_sizes: List[int] + _prompt_lens: Tensor | None + _prompt_mask: Tensor | None + _beam_sizes: list[int] _prompt_indices: Tensor _seqs: Tensor _step_scores: Tensor - _output: List[List[Hypothesis]] + _output: list[list[Hypothesis]] _counters: GenerationCounters def __init__( self, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, algorithm: BeamSearchAlgorithm, vocab_info: VocabularyInfo, beam_size: int, @@ -499,14 +414,15 @@ def __init__( temperature: float, unk_penalty: float, len_penalty: float, - prefill_chunk_size: Optional[int], - decode_capacity_increment: Optional[int], + prefill_chunk_size: int | None, + decode_capacity_increment: int | None, step_processors: Sequence[StepProcessor], - step_hooks: Dict[int, StepHook], + step_hooks: dict[int, StepHook], ) -> None: self._algorithm = algorithm - assert vocab_info.eos_idx is not None + if vocab_info.eos_idx is None: + raise InternalError("`vocab_info.eos_idx` is `None`.") self._eos_idx = vocab_info.eos_idx self._pad_idx = vocab_info.pad_idx @@ -514,8 +430,8 @@ def __init__( self._beam_size = beam_size - min_prompt_idx: Union[int, Tensor] - max_prompt_idx: Union[int, Tensor] + min_prompt_idx: int | Tensor + max_prompt_idx: int | Tensor if prompt_padding_mask is None: self._min_prompt_len, min_prompt_idx = prompt_seqs.size(1), 0 @@ -599,7 +515,7 @@ def __init__( self._counters = GenerationCounters() - def __call__(self) -> Tuple[List[List[Hypothesis]], GenerationCounters]: + def __call__(self) -> tuple[list[list[Hypothesis]], GenerationCounters]: self._prepare_state() watch = Stopwatch(start=True, device=self._seqs.device) @@ -651,7 +567,7 @@ def _prefill(self) -> None: lprobs = log_softmax(logits, dim=-1, dtype=torch.float32) if lprobs.isnan().any(): - raise RuntimeError( + raise SequenceGenerationError( "The model has produced one or more NaN probabilities during prefill. The sequence generator cannot continue." ) @@ -704,7 +620,7 @@ def _step(self) -> bool: lprobs.squeeze_(1) if lprobs.isnan().any(): - raise RuntimeError( + raise SequenceGenerationError( f"The model has produced one or more NaN probabilities at step {self._step_nr}. The sequence generator cannot continue." ) @@ -732,9 +648,9 @@ def _step(self) -> bool: batch_offset = 0 - new_beam_sizes: List[int] = [] + new_beam_sizes: list[int] = [] - beam_next_step_list: List[BeamStep] = [] + beam_next_step_list: list[BeamStep] = [] # We split the batch by `beam_sizes` and treat each beam separately. for beam_idx, (beam_lprobs, beam_step_scores) in enumerate( @@ -754,7 +670,8 @@ def _step(self) -> bool: beam_size = len(beam_next_step.seq_indices) # We should have terminated the beam if there are no sequences. - assert beam_size > 0 + if beam_size == 0: + raise InternalError("`beam_size` is zero.") new_beam_sizes.append(beam_size) @@ -789,15 +706,17 @@ def _step(self) -> bool: def _search_beam( self, beam_idx: int, batch_offset: int, lprobs: Tensor, step_scores: Tensor - ) -> Optional[BeamStep]: + ) -> BeamStep | None: # Ignore the generated indices for the prompt sequences. if self._step_nr < self._max_prompt_len: - assert self._prompt_mask is not None + if self._prompt_mask is None: + raise InternalError("`_prompt_mask` is `None`.") # Check if the current beam is in a prompt sequence. if self._prompt_mask[batch_offset, self._step_nr]: # The size of a beam in a prompt sequence must be always 1. - assert len(lprobs) == 1 + if len(lprobs) != 1: + raise InternalError(f"The length of `lprobs` is {len(lprobs)}.") seq_idx = torch.tensor([batch_offset], device=lprobs.device) @@ -816,7 +735,7 @@ def _search_beam( # best 2 x `beam_size` candidates and choose the first `beam_size` of # these which don't predict EOS to continue with. # (2 x B) - next_step = self._algorithm( + next_step = self._algorithm.step( self._beam_size, lprobs, step_scores[:, : self._step_nr] ) @@ -931,7 +850,7 @@ def __init__( self, model: DecoderModel, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, algorithm: BeamSearchAlgorithm, beam_size: int, min_gen_len: int, @@ -942,10 +861,10 @@ def __init__( temperature: float, unk_penalty: float, len_penalty: float, - prefill_chunk_size: Optional[int], - decode_capacity_increment: Optional[int], + prefill_chunk_size: int | None, + decode_capacity_increment: int | None, step_processors: Sequence[StepProcessor], - step_hooks: Dict[int, StepHook], + step_hooks: dict[int, StepHook], ) -> None: super().__init__( prompt_seqs, @@ -984,15 +903,15 @@ def _decode(self, seqs: Tensor) -> SequenceModelOutput: class _BeamSearchSeq2SeqGeneratorOp(_AbstractBeamSearchSequenceGeneratorOp): _model: EncoderDecoderModel _encoder_output: Tensor - _encoder_padding_mask: Optional[PaddingMask] + _encoder_padding_mask: PaddingMask | None def __init__( self, model: EncoderDecoderModel, encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + encoder_padding_mask: PaddingMask | None, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, algorithm: BeamSearchAlgorithm, beam_size: int, min_gen_len: int, @@ -1003,10 +922,10 @@ def __init__( temperature: float, unk_penalty: float, len_penalty: float, - prefill_chunk_size: Optional[int], - decode_capacity_increment: Optional[int], + prefill_chunk_size: int | None, + decode_capacity_increment: int | None, step_processors: Sequence[StepProcessor], - step_hooks: Dict[int, StepHook], + step_hooks: dict[int, StepHook], ) -> None: super().__init__( prompt_seqs, diff --git a/src/fairseq2/generation/beam_search/handler.py b/src/fairseq2/generation/beam_search/handler.py new file mode 100644 index 000000000..9b69be262 --- /dev/null +++ b/src/fairseq2/generation/beam_search/handler.py @@ -0,0 +1,208 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Final, final + +from typing_extensions import override + +from fairseq2.context import Provider +from fairseq2.generation.beam_search.algo import ( + STANDARD_BEAM_SEARCH_ALGO, + BeamSearchAlgorithmHandler, + BeamSearchAlgorithmNotFoundError, +) +from fairseq2.generation.beam_search.generator import ( + BeamSearchSeq2SeqGenerator, + BeamSearchSequenceGenerator, +) +from fairseq2.generation.generator import Seq2SeqGenerator, SequenceGenerator +from fairseq2.generation.handler import ( + Seq2SeqGeneratorHandler, + SequenceGeneratorHandler, +) +from fairseq2.models.decoder import DecoderModel +from fairseq2.models.encoder_decoder import EncoderDecoderModel +from fairseq2.typing import safe_cast +from fairseq2.utils.config import ConfigSectionHandler +from fairseq2.utils.structured import StructureError, structure + +BEAM_SEARCH_GENERATOR: Final = "beam_search" + + +@dataclass(kw_only=True) +class BeamSearchConfig: + algorithm: AlgorithmSection = field(default_factory=lambda: AlgorithmSection()) + """The beam search algorithm.""" + + beam_size: int = 5 + """The beam size.""" + + min_gen_len: int = 1 + """The minimum generation length.""" + + max_gen_len: int | tuple[int, int] = 2048 + """The maximum generation length.""" + + max_seq_len: int | None = None + """The maximum sequence length including prompt.""" + + echo_prompt: bool = False + """If ``True``, returns generated sequences with prompts appended.""" + + normalize_scores: bool = True + """If ``True``, normalizes scores by lengths of generated sequences.""" + + temperature: float = 1.0 + """The logit temperature.""" + + unk_penalty: float = 0.0 + """The UNK symbol penalty.""" + + len_penalty: float = 1.0 + """The length penalty.""" + + prefill_chunk_size: int | None = 512 + """The prefill will be performed incrementally by chunks of this size.""" + + decode_capacity_increment: int | None = 16 + """The sequence length capacity will be incremented by multiplies of this value.""" + + +@dataclass(kw_only=True) +class AlgorithmSection: + name: str = STANDARD_BEAM_SEARCH_ALGO + + config: object = None + + +@final +class AlgorithmSectionHandler(ConfigSectionHandler): + _algorithm_handlers: Provider[BeamSearchAlgorithmHandler] + + def __init__( + self, algorithm_handlers: Provider[BeamSearchAlgorithmHandler] + ) -> None: + self._algorithm_handlers = algorithm_handlers + + @override + def process(self, section: object) -> None: + section = safe_cast("section", section, AlgorithmSection) + + try: + algorithm_handler = self._algorithm_handlers.get(section.name) + except LookupError: + raise BeamSearchAlgorithmNotFoundError(section.name) from None + + try: + section.config = structure(section.config, algorithm_handler.config_kls) + except StructureError as ex: + raise StructureError( + "`config` cannot be structured. See the nested exception for details." + ) from ex + + +@final +class BeamSearchSequenceGeneratorHandler(SequenceGeneratorHandler): + _algorithm_handlers: Provider[BeamSearchAlgorithmHandler] + + def __init__( + self, algorithm_handlers: Provider[BeamSearchAlgorithmHandler] + ) -> None: + self._algorithm_handlers = algorithm_handlers + + @override + def create(self, model: DecoderModel, config: object) -> SequenceGenerator: + config = safe_cast("config", config, BeamSearchConfig) + + algorithm_section = config.algorithm + + try: + algorithm_handler = self._algorithm_handlers.get(algorithm_section.name) + except LookupError: + raise BeamSearchAlgorithmNotFoundError(algorithm_section.name) from None + + algorithm = algorithm_handler.create(algorithm_section.config) + + if isinstance(config.max_gen_len, int): + max_gen_len = config.max_gen_len + else: + if config.max_gen_len[0] != 1: + raise ValueError("`max_gen_len` must be an integer.") + + max_gen_len = config.max_gen_len[1] + + return BeamSearchSequenceGenerator( + model, + algorithm=algorithm, + beam_size=config.beam_size, + min_gen_len=config.min_gen_len, + max_gen_len=max_gen_len, + max_seq_len=config.max_seq_len, + echo_prompt=config.echo_prompt, + normalize_scores=config.normalize_scores, + temperature=config.temperature, + unk_penalty=config.unk_penalty, + len_penalty=config.len_penalty, + prefill_chunk_size=config.prefill_chunk_size, + decode_capacity_increment=config.decode_capacity_increment, + ) + + @property + @override + def config_kls(self) -> type: + return BeamSearchConfig + + +@final +class BeamSearchSeq2SeqGeneratorHandler(Seq2SeqGeneratorHandler): + _algorithm_handlers: Provider[BeamSearchAlgorithmHandler] + + def __init__( + self, algorithm_handlers: Provider[BeamSearchAlgorithmHandler] + ) -> None: + self._algorithm_handlers = algorithm_handlers + + @override + def create(self, model: EncoderDecoderModel, config: object) -> Seq2SeqGenerator: + config = safe_cast("config", config, BeamSearchConfig) + + algorithm_section = config.algorithm + + try: + algorithm_handler = self._algorithm_handlers.get(algorithm_section.name) + except LookupError: + raise BeamSearchAlgorithmNotFoundError(algorithm_section.name) from None + + algorithm = algorithm_handler.create(algorithm_section.config) + + max_gen_len = config.max_gen_len + + if isinstance(max_gen_len, int): + max_gen_len = (1, max_gen_len) + + return BeamSearchSeq2SeqGenerator( + model, + algorithm=algorithm, + beam_size=config.beam_size, + min_gen_len=config.min_gen_len, + max_gen_len=max_gen_len, + max_seq_len=config.max_seq_len, + echo_prompt=config.echo_prompt, + normalize_scores=config.normalize_scores, + temperature=config.temperature, + unk_penalty=config.unk_penalty, + len_penalty=config.len_penalty, + prefill_chunk_size=config.prefill_chunk_size, + decode_capacity_increment=config.decode_capacity_increment, + ) + + @property + @override + def config_kls(self) -> type: + return BeamSearchConfig diff --git a/src/fairseq2/generation/generator.py b/src/fairseq2/generation/generator.py index f737430df..a6c506e4c 100644 --- a/src/fairseq2/generation/generator.py +++ b/src/fairseq2/generation/generator.py @@ -9,15 +9,15 @@ from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Dict, List, Optional, Protocol, final +from typing import Protocol, final from torch import Tensor from torch.utils.hooks import RemovableHandle +from typing_extensions import override from fairseq2.models.decoder import DecoderModel from fairseq2.models.encoder_decoder import EncoderDecoderModel from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import override class SequenceGenerator(ABC): @@ -25,7 +25,7 @@ class SequenceGenerator(ABC): @abstractmethod def __call__( - self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask] + self, prompt_seqs: Tensor, prompt_padding_mask: PaddingMask | None ) -> SequenceGeneratorOutput: """ :param prompt_seqs: @@ -55,11 +55,64 @@ def model(self) -> DecoderModel: """The associated decoder model.""" +@final +@dataclass +class SequenceGeneratorOutput: + """Holds the output of a sequence generator.""" + + hypotheses: list[list[Hypothesis]] + """The list of hypothesis generated per prompt, ordered by score.""" + + counters: GenerationCounters + """The performance counters of the call.""" + + +@final +@dataclass +class Hypothesis: + """Represents a hypothesis produced by a sequence generator.""" + + seq: Tensor + """The generated sequence. *Shape:* :math:`(S)`, where :math:`S` is the + sequence length.""" + + score: Tensor | None + """The score of the hypothesis. *Shape:* Scalar.""" + + step_scores: Tensor | None + """The score of each sequence step. *Shape:* Same as ``seq``.""" + + +@final +@dataclass +class GenerationCounters: + """Holds the performance counters of a generator call.""" + + prefill_size: int = 0 + """The number of elements processed during the prefill step.""" + + num_generated_elements: int = 0 + """The number of generated elements.""" + + generation_time: float = 0 + """The generation time excluding prefill.""" + + cache_size: int = 0 + """The final size of the incremental cache in bytes.""" + + cache_capacity: int = 0 + """The final reserved capacity of the incremental cache in bytes.""" + + +class SequenceGenerationError(Exception): + pass + + class AbstractSequenceGenerator(SequenceGenerator): """Provides a skeletal implementation of :class:`SequenceGenerator`.""" _model: DecoderModel - _step_hooks: Dict[int, StepHook] + _step_hooks: dict[int, StepHook] def __init__(self, model: DecoderModel) -> None: """ @@ -93,18 +146,6 @@ def model(self) -> DecoderModel: return self._model -@final -@dataclass -class SequenceGeneratorOutput: - """Holds the output of a sequence generator.""" - - hypotheses: List[List[Hypothesis]] - """The list of hypothesis generated per prompt, ordered by score.""" - - counters: GenerationCounters - """The performance counters of the call.""" - - class Seq2SeqGenerator(ABC): """Represents a sequence-to-sequence generator.""" @@ -112,9 +153,9 @@ class Seq2SeqGenerator(ABC): def __call__( self, source_seqs: Tensor, - source_padding_mask: Optional[PaddingMask], + source_padding_mask: PaddingMask | None, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, ) -> Seq2SeqGeneratorOutput: """ :param source_seqs: @@ -151,11 +192,32 @@ def model(self) -> EncoderDecoderModel: """The associated encoder-decoder model.""" +@final +@dataclass +class Seq2SeqGeneratorOutput: + hypotheses: list[list[Hypothesis]] + """The list of hypothesis generated per prompt, ordered by score.""" + + encoder_output: Tensor + """The encoder output used in encoder-decoder attention. *Shape:* + :math:`(N,S_{enc},M)`, where :math:`N` is the batch size, :math:`S_{enc}` is + the encoder output sequence length, and :math:`M` is the dimensionality of + the model.""" + + encoder_padding_mask: PaddingMask | None + """The padding mask of :attr:`encoder_output`. *Shape:* :math:`(N,S_{enc})`, + where :math:`N` is the batch size and :math:`S_{enc}` is the encoder output + sequence length.""" + + counters: GenerationCounters + """The performance counters of the call.""" + + class AbstractSeq2SeqGenerator(Seq2SeqGenerator): """Provides a skeletal implementation of :class:`Seq2SeqGenerator`.""" _model: EncoderDecoderModel - _step_hooks: Dict[int, StepHook] + _step_hooks: dict[int, StepHook] def __init__(self, model: EncoderDecoderModel) -> None: """ @@ -189,64 +251,6 @@ def model(self) -> EncoderDecoderModel: return self._model -@final -@dataclass -class Seq2SeqGeneratorOutput: - hypotheses: List[List[Hypothesis]] - """The list of hypothesis generated per prompt, ordered by score.""" - - encoder_output: Tensor - """The encoder output used in encoder-decoder attention. *Shape:* - :math:`(N,S_{enc},M)`, where :math:`N` is the batch size, :math:`S_{enc}` is - the encoder output sequence length, and :math:`M` is the dimensionality of - the model.""" - - encoder_padding_mask: Optional[PaddingMask] - """The padding mask of :attr:`encoder_output`. *Shape:* :math:`(N,S_{enc})`, - where :math:`N` is the batch size and :math:`S_{enc}` is the encoder output - sequence length.""" - - counters: GenerationCounters - """The performance counters of the call.""" - - -@final -@dataclass -class Hypothesis: - """Represents a hypothesis produced by a sequence generator.""" - - seq: Tensor - """The generated sequence. *Shape:* :math:`(S)`, where :math:`S` is the - sequence length.""" - - score: Optional[Tensor] - """The score of the hypothesis. *Shape:* Scalar.""" - - step_scores: Optional[Tensor] - """The score of each sequence step. *Shape:* Same as ``seq``.""" - - -@final -@dataclass -class GenerationCounters: - """Holds the performance counters of a generator call.""" - - prefill_size: int = 0 - """The number of elements processed during the prefill step.""" - - num_generated_elements: int = 0 - """The number of generated elements.""" - - generation_time: float = 0 - """The generation time excluding prefill.""" - - cache_size: int = 0 - """The final size of the incremental cache in bytes.""" - - cache_capacity: int = 0 - """The final reserved capacity of the incremental cache in bytes.""" - - class StepHook(Protocol): """Represents a hook to pass to :meth:`~SequenceGenerator.register_step_hook` or :meth:`~Seq2SeqGenerator.register_step_hook`.""" @@ -255,7 +259,7 @@ def __call__( self, prompt_indices: Tensor, seqs: Tensor, - step_scores: Optional[Tensor], + step_scores: Tensor | None, prefill: bool, ) -> None: """ diff --git a/src/fairseq2/generation/handler.py b/src/fairseq2/generation/handler.py new file mode 100644 index 000000000..de3126d30 --- /dev/null +++ b/src/fairseq2/generation/handler.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from fairseq2.generation.generator import Seq2SeqGenerator, SequenceGenerator +from fairseq2.models.decoder import DecoderModel +from fairseq2.models.encoder_decoder import EncoderDecoderModel + + +class SequenceGeneratorHandler(ABC): + @abstractmethod + def create(self, model: DecoderModel, config: object) -> SequenceGenerator: + ... + + @property + @abstractmethod + def config_kls(self) -> type: + ... + + +class SequenceGeneratorNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known sequence generator.") + + self.name = name + + +class Seq2SeqGeneratorHandler(ABC): + @abstractmethod + def create(self, model: EncoderDecoderModel, config: object) -> Seq2SeqGenerator: + ... + + @property + @abstractmethod + def config_kls(self) -> type: + ... + + +class Seq2SeqGeneratorNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known sequence-to-sequence generator.") + + self.name = name diff --git a/src/fairseq2/generation/sampling/__init__.py b/src/fairseq2/generation/sampling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/fairseq2/generation/sampling.py b/src/fairseq2/generation/sampling/generator.py similarity index 86% rename from src/fairseq2/generation/sampling.py rename to src/fairseq2/generation/sampling/generator.py index 283e33e37..1015f0766 100644 --- a/src/fairseq2/generation/sampling.py +++ b/src/fairseq2/generation/sampling/generator.py @@ -7,22 +7,27 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Protocol, Sequence, Tuple, Union, final +from collections.abc import Sequence +from typing import final import torch from torch import Tensor from torch.nn.functional import softmax +from typing_extensions import override from fairseq2.data import VocabularyInfo +from fairseq2.error import InternalError from fairseq2.generation.generator import ( AbstractSeq2SeqGenerator, AbstractSequenceGenerator, GenerationCounters, Hypothesis, Seq2SeqGeneratorOutput, + SequenceGenerationError, SequenceGeneratorOutput, StepHook, ) +from fairseq2.generation.sampling.sampler import Sampler from fairseq2.generation.step_processor import StepProcessor from fairseq2.models.decoder import DecoderModel from fairseq2.models.encoder_decoder import EncoderDecoderModel @@ -30,7 +35,6 @@ from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.ops import repeat_interleave from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import override from fairseq2.utils.profiler import Stopwatch @@ -49,8 +53,8 @@ class SamplingSequenceGenerator(AbstractSequenceGenerator): _temperature: float _unk_penalty: float _len_penalty: float - _prefill_chunk_size: Optional[int] - _decode_capacity_increment: Optional[int] + _prefill_chunk_size: int | None + _decode_capacity_increment: int | None _step_processors: Sequence[StepProcessor] def __init__( @@ -61,16 +65,16 @@ def __init__( num_gens: int = 1, min_gen_len: int = 1, max_gen_len: int = 128, - max_seq_len: Optional[int] = None, + max_seq_len: int | None = None, echo_prompt: bool = False, compute_scores: bool = False, normalize_scores: bool = True, temperature: float = 0.6, unk_penalty: float = 0.0, len_penalty: float = 1.0, - prefill_chunk_size: Optional[int] = 512, - decode_capacity_increment: Optional[int] = 16, - step_processors: Optional[Sequence[StepProcessor]] = None, + prefill_chunk_size: int | None = 512, + decode_capacity_increment: int | None = 16, + step_processors: Sequence[StepProcessor] | None = None, ) -> None: """ :param model: @@ -164,7 +168,7 @@ def __init__( @torch.inference_mode() @override def __call__( - self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask] + self, prompt_seqs: Tensor, prompt_padding_mask: PaddingMask | None ) -> SequenceGeneratorOutput: op = _SamplingSequenceGeneratorOp( self._model, @@ -199,7 +203,7 @@ class SamplingSeq2SeqGenerator(AbstractSeq2SeqGenerator): _sampler: Sampler _num_gens: int _min_gen_len: int - _max_gen_len: Tuple[int, int] + _max_gen_len: tuple[int, int] _max_seq_len: int _echo_prompt: bool _compute_scores: bool @@ -207,8 +211,8 @@ class SamplingSeq2SeqGenerator(AbstractSeq2SeqGenerator): _temperature: float _unk_penalty: float _len_penalty: float - _prefill_chunk_size: Optional[int] - _decode_capacity_increment: Optional[int] + _prefill_chunk_size: int | None + _decode_capacity_increment: int | None _step_processors: Sequence[StepProcessor] def __init__( @@ -218,17 +222,17 @@ def __init__( *, num_gens: int = 1, min_gen_len: int = 1, - max_gen_len: Tuple[int, int] = (1, 128), - max_seq_len: Optional[int] = None, + max_gen_len: tuple[int, int] = (1, 128), + max_seq_len: int | None = None, echo_prompt: bool = False, compute_scores: bool = False, normalize_scores: bool = True, temperature: float = 0.6, unk_penalty: float = 0.0, len_penalty: float = 1.0, - prefill_chunk_size: Optional[int] = 512, - decode_capacity_increment: Optional[int] = 16, - step_processors: Optional[Sequence[StepProcessor]] = None, + prefill_chunk_size: int | None = 512, + decode_capacity_increment: int | None = 16, + step_processors: Sequence[StepProcessor] | None = None, ) -> None: """ :param model: @@ -315,9 +319,9 @@ def __init__( def __call__( self, source_seqs: Tensor, - source_padding_mask: Optional[PaddingMask], + source_padding_mask: PaddingMask | None, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, ) -> Seq2SeqGeneratorOutput: # (P, S) encoder_output, encoder_padding_mask = self.model.encode( @@ -375,102 +379,11 @@ def __call__( ) -class Sampler(Protocol): - """Represents a sampling algorithm.""" - - def __call__(self, probs: Tensor) -> Tensor: - """ - :param probs: - The next-step probability of each vocabulary entry. *Shape:* - :math:`(N,V)`, where :math:`N` is the batch size and :math:`V` is - the size of the vocabulary. - """ - - -@final -class TopPSampler(Sampler): - """Selects the next step randomly from the smallest set of candidates for - which the cumulative probability exceeds a specified value p. - - Also known as Nucleus Sampling as described in - :cite:t:`https://doi.org/10.48550/arxiv.1904.09751`. - """ - - _p: float - - def __init__(self, p: float = 0.9) -> None: - """ - :param p: - The cumulative probability threshold. - """ - self._p = p - - def __call__(self, probs: Tensor) -> Tensor: - # Previous operations in the generation like step processors might have - # modified the probabilities. Normalize the distribution. - probs = probs / probs.sum(dim=-1, keepdim=True) - - sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) - - # (N, V) - cumsum_probs = torch.cumsum(sorted_probs, dim=-1) - - mask = (cumsum_probs - sorted_probs) > self._p - - sorted_probs[mask] = 0.0 - - # Normalize. - sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True) - - # (N, 1) - indices = sorted_indices.gather( - dim=-1, index=torch.multinomial(sorted_probs, num_samples=1) - ) - - # (N, 1) -> (N) - return indices.squeeze(-1) # type: ignore[no-any-return] - - -@final -class TopKSampler(Sampler): - """Selects the next step randomly from the k mosty likely candidates.""" - - _k: int - - def __init__(self, k: int) -> None: - """ - :param k: - The number of candidates to select from. - """ - self._k = k - - def __call__(self, probs: Tensor) -> Tensor: - k = min(self._k, probs.size(1)) - - if k == 1: - # (N, 1) - indices = torch.argmax(probs, dim=-1, keepdim=True) - else: - # (N, V) -> (N, K) - topk_probs, topk_indices = torch.topk(probs, k=k, dim=-1, sorted=False) - - # Normalize. - topk_probs /= topk_probs.sum(dim=-1, keepdim=True) - - # (N, 1) - indices = topk_indices.gather( - dim=-1, index=torch.multinomial(topk_probs, num_samples=1) - ) - - # (N, 1) -> (N) - return indices.squeeze(-1) - - class _AbstractSamplingSequenceGeneratorOp(ABC): _sampler: Sampler _eos_idx: int - _pad_idx: Optional[int] - _unk_idx: Optional[int] + _pad_idx: int | None + _unk_idx: int | None _num_gens: int _min_prompt_len: int _max_prompt_len: int @@ -482,23 +395,23 @@ class _AbstractSamplingSequenceGeneratorOp(ABC): _temperature: float _unk_penalty: float _len_penalty: float - _prefill_chunk_size: Optional[int] + _prefill_chunk_size: int | None _step_processors: Sequence[StepProcessor] - _step_hooks: Dict[int, StepHook] + _step_hooks: dict[int, StepHook] _step_nr: int _state_bag: IncrementalStateBag - _prompt_lens: Optional[Tensor] - _prompt_mask: Optional[Tensor] + _prompt_lens: Tensor | None + _prompt_mask: Tensor | None _prompt_indices: Tensor _seqs: Tensor - _step_scores: Optional[Tensor] - _output: List[List[Hypothesis]] + _step_scores: Tensor | None + _output: list[list[Hypothesis]] _counters: GenerationCounters def __init__( self, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, sampler: Sampler, vocab_info: VocabularyInfo, num_gens: int, @@ -511,14 +424,15 @@ def __init__( temperature: float, unk_penalty: float, len_penalty: float, - prefill_chunk_size: Optional[int], - decode_capacity_increment: Optional[int], + prefill_chunk_size: int | None, + decode_capacity_increment: int | None, step_processors: Sequence[StepProcessor], - step_hooks: Dict[int, StepHook], + step_hooks: dict[int, StepHook], ) -> None: self._sampler = sampler - assert vocab_info.eos_idx is not None + if vocab_info.eos_idx is None: + raise InternalError("`vocab_info.eos_idx` is `None`.") self._eos_idx = vocab_info.eos_idx self._pad_idx = vocab_info.pad_idx @@ -526,8 +440,8 @@ def __init__( self._num_gens = num_gens - min_prompt_idx: Union[int, Tensor] - max_prompt_idx: Union[int, Tensor] + min_prompt_idx: int | Tensor + max_prompt_idx: int | Tensor if prompt_padding_mask is None: self._min_prompt_len, min_prompt_idx = prompt_seqs.size(1), 0 @@ -612,7 +526,7 @@ def __init__( self._counters = GenerationCounters() - def __call__(self) -> Tuple[List[List[Hypothesis]], GenerationCounters]: + def __call__(self) -> tuple[list[list[Hypothesis]], GenerationCounters]: self._prepare_state() watch = Stopwatch(start=True, device=self._seqs.device) @@ -677,7 +591,7 @@ def _prefill(self) -> None: probs = softmax(logits, dim=-1, dtype=torch.float32) if probs.isnan().any(): - raise RuntimeError( + raise SequenceGenerationError( "The model has produced one or more NaN probabilities during prefill. The sequence generator cannot continue." ) @@ -729,7 +643,7 @@ def _step(self) -> bool: probs.squeeze_(1) if probs.isnan().any(): - raise RuntimeError( + raise SequenceGenerationError( f"The model has produced one or more NaN probabilities at step {self._step_nr}. The sequence generator cannot continue." ) @@ -758,7 +672,7 @@ def _step(self) -> bool: probs[:, self._eos_idx] = 0 # (N) - vocab_indices = self._sampler(probs) + vocab_indices = self._sampler.sample(probs) # EOS mask of the current step. # (N) @@ -766,7 +680,8 @@ def _step(self) -> bool: # Ignore the generated indices for the prompt sequences. if self._step_nr < self._max_prompt_len: - assert self._prompt_mask is not None + if self._prompt_mask is None: + raise InternalError("`_prompt_mask` is `None`.") # (N) mask = self._prompt_mask[:, self._step_nr] @@ -903,7 +818,7 @@ def __init__( self, model: DecoderModel, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, sampler: Sampler, num_gens: int, min_gen_len: int, @@ -915,10 +830,10 @@ def __init__( temperature: float, unk_penalty: float, len_penalty: float, - prefill_chunk_size: Optional[int], - decode_capacity_increment: Optional[int], + prefill_chunk_size: int | None, + decode_capacity_increment: int | None, step_processors: Sequence[StepProcessor], - step_hooks: Dict[int, StepHook], + step_hooks: dict[int, StepHook], ) -> None: super().__init__( prompt_seqs, @@ -958,15 +873,15 @@ def _decode(self, seqs: Tensor) -> SequenceModelOutput: class _SamplingSeq2SeqGeneratorOp(_AbstractSamplingSequenceGeneratorOp): _model: EncoderDecoderModel _encoder_output: Tensor - _encoder_padding_mask: Optional[PaddingMask] + _encoder_padding_mask: PaddingMask | None def __init__( self, model: EncoderDecoderModel, encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + encoder_padding_mask: PaddingMask | None, prompt_seqs: Tensor, - prompt_padding_mask: Optional[PaddingMask], + prompt_padding_mask: PaddingMask | None, sampler: Sampler, num_gens: int, min_gen_len: int, @@ -978,10 +893,10 @@ def __init__( temperature: float, unk_penalty: float, len_penalty: float, - prefill_chunk_size: Optional[int], - decode_capacity_increment: Optional[int], + prefill_chunk_size: int | None, + decode_capacity_increment: int | None, step_processors: Sequence[StepProcessor], - step_hooks: Dict[int, StepHook], + step_hooks: dict[int, StepHook], ) -> None: super().__init__( prompt_seqs, diff --git a/src/fairseq2/generation/sampling/handler.py b/src/fairseq2/generation/sampling/handler.py new file mode 100644 index 000000000..c6ba477bb --- /dev/null +++ b/src/fairseq2/generation/sampling/handler.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Final, final + +from typing_extensions import override + +from fairseq2.context import Provider +from fairseq2.generation.generator import Seq2SeqGenerator, SequenceGenerator +from fairseq2.generation.handler import ( + Seq2SeqGeneratorHandler, + SequenceGeneratorHandler, +) +from fairseq2.generation.sampling.generator import ( + SamplingSeq2SeqGenerator, + SamplingSequenceGenerator, +) +from fairseq2.generation.sampling.sampler import ( + TOP_P_SAMPLER, + SamplerHandler, + SamplerNotFoundError, + TopPSamplerConfig, +) +from fairseq2.models.decoder import DecoderModel +from fairseq2.models.encoder_decoder import EncoderDecoderModel +from fairseq2.typing import safe_cast +from fairseq2.utils.config import ConfigSectionHandler +from fairseq2.utils.structured import StructureError, structure + +SAMPLING_GENERATOR: Final = "sampling" + + +@dataclass(kw_only=True) +class SamplingConfig: + sampler: SamplerSection = field(default_factory=lambda: SamplerSection()) + """The configuration of the sampler.""" + + min_gen_len: int = 1 + """The minimum generation length.""" + + max_gen_len: int | tuple[int, int] = 2048 + """The maximum generation length.""" + + max_seq_len: int | None = None + """The maximum sequence length including prompt.""" + + echo_prompt: bool = False + """If ``True``, returns generated sequences with prompts appended.""" + + compute_scores: bool = False + """If ``True``, computes scores of generated sequences.""" + + normalize_scores: bool = True + """If ``True``, normalizes scores by lengths of generated sequences.""" + + temperature: float = 1.0 + """The logit temperature.""" + + unk_penalty: float = 0.0 + """The UNK symbol penalty.""" + + len_penalty: float = 1.0 + """The length penalty.""" + + prefill_chunk_size: int | None = 512 + """The prefill will be performed incrementally by chunks of this size.""" + + decode_capacity_increment: int | None = 16 + """The sequence length capacity will be incremented by multiplies of this value.""" + + +@dataclass(kw_only=True) +class SamplerSection: + name: str = TOP_P_SAMPLER + + config: object = field(default_factory=lambda: TopPSamplerConfig()) + + +@final +class SamplerSectionHandler(ConfigSectionHandler): + _sampler_handlers: Provider[SamplerHandler] + + def __init__(self, sampler_handlers: Provider[SamplerHandler]) -> None: + self._sampler_handlers = sampler_handlers + + @override + def process(self, section: object) -> None: + section = safe_cast("section", section, SamplerSection) + + try: + sampler_handler = self._sampler_handlers.get(section.name) + except LookupError: + raise SamplerNotFoundError(section.name) from None + + try: + section.config = structure(section.config, sampler_handler.config_kls) + except StructureError as ex: + raise StructureError( + "`config` cannot be structured. See the nested exception for details." + ) from ex + + +@final +class SamplingSequenceGeneratorHandler(SequenceGeneratorHandler): + _sampler_handlers: Provider[SamplerHandler] + + def __init__(self, sampler_handlers: Provider[SamplerHandler]) -> None: + self._sampler_handlers = sampler_handlers + + @override + def create(self, model: DecoderModel, config: object) -> SequenceGenerator: + config = safe_cast("config", config, SamplingConfig) + + sampler_section = config.sampler + + try: + sampler_handler = self._sampler_handlers.get(sampler_section.name) + except LookupError: + raise SamplerNotFoundError(sampler_section.name) from None + + sampler = sampler_handler.create(sampler_section.config) + + if isinstance(config.max_gen_len, int): + max_gen_len = config.max_gen_len + else: + if config.max_gen_len[0] != 1: + raise ValueError("`max_gen_len` must be an integer.") + + max_gen_len = config.max_gen_len[1] + + return SamplingSequenceGenerator( + model, + sampler, + min_gen_len=config.min_gen_len, + max_gen_len=max_gen_len, + max_seq_len=config.max_seq_len, + echo_prompt=config.echo_prompt, + compute_scores=config.compute_scores, + normalize_scores=config.normalize_scores, + temperature=config.temperature, + unk_penalty=config.unk_penalty, + len_penalty=config.len_penalty, + prefill_chunk_size=config.prefill_chunk_size, + decode_capacity_increment=config.decode_capacity_increment, + ) + + @property + @override + def config_kls(self) -> type: + return SamplingConfig + + +@final +class SamplingSeq2SeqGeneratorHandler(Seq2SeqGeneratorHandler): + _sampler_handlers: Provider[SamplerHandler] + + def __init__(self, sampler_handlers: Provider[SamplerHandler]) -> None: + self._sampler_handlers = sampler_handlers + + @override + def create(self, model: EncoderDecoderModel, config: object) -> Seq2SeqGenerator: + config = safe_cast("config", config, SamplingConfig) + + sampler_section = config.sampler + + try: + sampler_handler = self._sampler_handlers.get(sampler_section.name) + except LookupError: + raise SamplerNotFoundError(sampler_section.name) from None + + sampler = sampler_handler.create(sampler_section.config) + + max_gen_len = config.max_gen_len + + if isinstance(max_gen_len, int): + max_gen_len = (1, max_gen_len) + + return SamplingSeq2SeqGenerator( + model, + sampler, + min_gen_len=config.min_gen_len, + max_gen_len=max_gen_len, + max_seq_len=config.max_seq_len, + echo_prompt=config.echo_prompt, + compute_scores=config.compute_scores, + normalize_scores=config.normalize_scores, + temperature=config.temperature, + unk_penalty=config.unk_penalty, + len_penalty=config.len_penalty, + prefill_chunk_size=config.prefill_chunk_size, + decode_capacity_increment=config.decode_capacity_increment, + ) + + @property + @override + def config_kls(self) -> type: + return SamplingConfig diff --git a/src/fairseq2/generation/sampling/sampler.py b/src/fairseq2/generation/sampling/sampler.py new file mode 100644 index 000000000..f5db6bff5 --- /dev/null +++ b/src/fairseq2/generation/sampling/sampler.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Final, final + +import torch +from torch import Tensor +from typing_extensions import override + +from fairseq2.typing import safe_cast + + +class Sampler(ABC): + """Represents a sampling algorithm.""" + + @abstractmethod + def sample(self, probs: Tensor) -> Tensor: + """ + :param probs: + The next-step probability of each vocabulary entry. *Shape:* + :math:`(N,V)`, where :math:`N` is the batch size and :math:`V` is + the size of the vocabulary. + """ + + +@final +class TopPSampler(Sampler): + """Selects the next step randomly from the smallest set of candidates for + which the cumulative probability exceeds a specified value p. + + Also known as Nucleus Sampling as described in + :cite:t:`https://doi.org/10.48550/arxiv.1904.09751`. + """ + + _p: float + + def __init__(self, p: float = 0.9) -> None: + """ + :param p: + The cumulative probability threshold. + """ + self._p = p + + @override + def sample(self, probs: Tensor) -> Tensor: + # Previous operations in the generation like step processors might have + # modified the probabilities. Normalize the distribution. + probs = probs / probs.sum(dim=-1, keepdim=True) + + sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) + + # (N, V) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + + mask = (cumsum_probs - sorted_probs) > self._p + + sorted_probs[mask] = 0.0 + + # Normalize. + sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True) + + # (N, 1) + indices = sorted_indices.gather( + dim=-1, index=torch.multinomial(sorted_probs, num_samples=1) + ) + + # (N, 1) -> (N) + return indices.squeeze(-1) # type: ignore[no-any-return] + + +@final +class TopKSampler(Sampler): + """Selects the next step randomly from the k mosty likely candidates.""" + + _k: int + + def __init__(self, k: int) -> None: + """ + :param k: + The number of candidates to select from. + """ + self._k = k + + @override + def sample(self, probs: Tensor) -> Tensor: + k = min(self._k, probs.size(1)) + + if k == 1: + # (N, 1) + indices = torch.argmax(probs, dim=-1, keepdim=True) + else: + # (N, V) -> (N, K) + topk_probs, topk_indices = torch.topk(probs, k=k, dim=-1, sorted=False) + + # Normalize. + topk_probs /= topk_probs.sum(dim=-1, keepdim=True) + + # (N, 1) + indices = topk_indices.gather( + dim=-1, index=torch.multinomial(topk_probs, num_samples=1) + ) + + # (N, 1) -> (N) + return indices.squeeze(-1) + + +class SamplerHandler(ABC): + @abstractmethod + def create(self, config: object) -> Sampler: + ... + + @property + @abstractmethod + def config_kls(self) -> type: + ... + + +class SamplerNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known sequence generator sampler.") + + self.name = name + + +TOP_P_SAMPLER: Final = "top-p" + + +@dataclass(kw_only=True) +class TopPSamplerConfig: + p: float = 1.0 + + +@final +class TopPSamplerHandler(SamplerHandler): + @override + def create(self, config: object) -> Sampler: + config = safe_cast("config", config, TopPSamplerConfig) + + return TopPSampler(p=config.p) + + @property + @override + def config_kls(self) -> type: + return TopPSamplerConfig + + +TOP_K_SAMPLER: Final = "top-k" + + +@dataclass(kw_only=True) +class TopKSamplerConfig: + k: int = 1 + + +@final +class TopKSamplerHandler(SamplerHandler): + @override + def create(self, config: object) -> Sampler: + config = safe_cast("config", config, TopKSamplerConfig) + + return TopKSampler(k=config.k) + + @property + @override + def config_kls(self) -> type: + return TopKSamplerConfig diff --git a/src/fairseq2/generation/static.py b/src/fairseq2/generation/static.py new file mode 100644 index 000000000..f02eccd04 --- /dev/null +++ b/src/fairseq2/generation/static.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import get_runtime_context +from fairseq2.generation.generator import Seq2SeqGenerator, SequenceGenerator +from fairseq2.generation.handler import ( + Seq2SeqGeneratorHandler, + Seq2SeqGeneratorNotFoundError, + SequenceGeneratorHandler, + SequenceGeneratorNotFoundError, +) +from fairseq2.models.decoder import DecoderModel +from fairseq2.models.encoder_decoder import EncoderDecoderModel +from fairseq2.utils.config import process_config +from fairseq2.utils.structured import structure + + +def create_seq_generator( + name: str, model: DecoderModel, config: object = None +) -> SequenceGenerator: + context = get_runtime_context() + + registry = context.get_registry(SequenceGeneratorHandler) + + try: + handler = registry.get(name) + except LookupError: + raise SequenceGeneratorNotFoundError(name) from None + + if config is None: + try: + config = handler.config_kls() + except TypeError: + raise ValueError( + f"`config` must be specified for the '{name}' sequence generator." + ) from None + else: + config = structure(config, handler.config_kls) + + process_config(config) + + return handler.create(model, config) + + +def create_seq2seq_generator( + name: str, model: EncoderDecoderModel, config: object = None +) -> Seq2SeqGenerator: + context = get_runtime_context() + + registry = context.get_registry(Seq2SeqGeneratorHandler) + + try: + handler = registry.get(name) + except LookupError: + raise Seq2SeqGeneratorNotFoundError(name) from None + + if config is None: + try: + config = handler.config_kls() + except TypeError: + raise ValueError( + f"`config` must be specified for the '{name}' sequence-to-sequence generator." + ) from None + else: + config = structure(config, handler.config_kls) + + process_config(config) + + return handler.create(model, config) diff --git a/src/fairseq2/generation/step_processor.py b/src/fairseq2/generation/step_processor.py index 7d5e8c7f3..0f613e811 100644 --- a/src/fairseq2/generation/step_processor.py +++ b/src/fairseq2/generation/step_processor.py @@ -7,7 +7,8 @@ from __future__ import annotations import sys -from typing import List, Optional, Protocol, Sequence, final +from collections.abc import Sequence +from typing import Protocol, final import torch from torch import Tensor @@ -36,8 +37,8 @@ def __call__(self, seqs: Tensor, probs: Tensor, lprob: bool = False) -> None: class BannedSequenceProcessor(StepProcessor): """Prevents a provided list of banned sequences from being generated.""" - _banned_seqs: Optional[Tensor] - _banned_mask: Optional[Tensor] + _banned_seqs: Tensor | None + _banned_mask: Tensor | None def __init__(self, banned_seqs: Sequence[Tensor]) -> None: """ @@ -55,7 +56,7 @@ def __init__(self, banned_seqs: Sequence[Tensor]) -> None: max_seq_len = 0 min_seq_len = sys.maxsize - seq_lens: List[int] = [] + seq_lens: list[int] = [] for idx, seq in enumerate(banned_seqs): seq_len = len(seq) diff --git a/src/fairseq2/generation/text.py b/src/fairseq2/generation/text.py index a54c8c4d9..90b51e837 100644 --- a/src/fairseq2/generation/text.py +++ b/src/fairseq2/generation/text.py @@ -6,11 +6,13 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Tuple, final +from collections.abc import Sequence +from typing import final from torch import Tensor from fairseq2.data.text import TextTokenDecoder, TextTokenEncoder, TextTokenizer +from fairseq2.error import ContractError from fairseq2.generation.generator import ( Seq2SeqGenerator, Seq2SeqGeneratorOutput, @@ -34,7 +36,7 @@ def __init__( generator: Seq2SeqGenerator, tokenizer: TextTokenizer, task: str, - target_lang: Optional[str] = None, + target_lang: str | None = None, ) -> None: """ :param generator: @@ -48,7 +50,12 @@ def __init__( """ self._generator = generator - device = infer_device(generator.model, name="generator.model") + try: + device = infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex target_text_encoder = tokenizer.create_encoder( task=task, lang=target_lang, mode="target", device=device @@ -65,7 +72,7 @@ def __init__( self._text_decoder = tokenizer.create_decoder() - def __call__(self, source_seq: Tensor) -> Tuple[str, Seq2SeqGeneratorOutput]: + def __call__(self, source_seq: Tensor) -> tuple[str, Seq2SeqGeneratorOutput]: """ :param source_seq: The source sequence. *Shape:* :math:`(S,*)`, where :math:`S` is the @@ -85,8 +92,8 @@ def __call__(self, source_seq: Tensor) -> Tuple[str, Seq2SeqGeneratorOutput]: def batch_convert( self, source_seqs: Tensor, - source_padding_mask: Optional[PaddingMask], - ) -> Tuple[List[str], Seq2SeqGeneratorOutput]: + source_padding_mask: PaddingMask | None, + ) -> tuple[list[str], Seq2SeqGeneratorOutput]: """ :param source_seqs: The source sequences. *Shape:* :math:`(N,S,*)`, where :math:`N` is @@ -110,8 +117,8 @@ def batch_convert( def _do_convert( self, source_seqs: Tensor, - source_padding_mask: Optional[PaddingMask], - ) -> Tuple[List[str], Seq2SeqGeneratorOutput]: + source_padding_mask: PaddingMask | None, + ) -> tuple[list[str], Seq2SeqGeneratorOutput]: """A subclass should call this method for actual text conversion. :param source_seqs: @@ -135,12 +142,12 @@ def _do_convert( source_seqs, source_padding_mask, target_prefix_seqs, None ) - texts: List[str] = [] + texts: list[str] = [] for idx, hypotheses in enumerate(generator_output.hypotheses): if len(hypotheses) == 0: - raise RuntimeError( - f"The sequence generator returned no hypothesis at index {idx}. Please file a bug report." + raise ContractError( + f"The sequence generator returned no hypothesis at index {idx}." ) texts.append(self._text_decoder(hypotheses[0].seq)) @@ -155,16 +162,16 @@ class TextTranslator: _converter: SequenceToTextConverter _pad_idx: int _source_text_encoder: TextTokenEncoder - _max_source_len: Optional[int] + _max_source_len: int | None def __init__( self, generator: Seq2SeqGenerator, tokenizer: TextTokenizer, - source_lang: Optional[str] = None, - target_lang: Optional[str] = None, + source_lang: str | None = None, + target_lang: str | None = None, *, - max_source_len: Optional[int] = None, + max_source_len: int | None = None, ) -> None: """ :param generator: @@ -191,7 +198,12 @@ def __init__( self._pad_idx = pad_idx - device = infer_device(generator.model, name="generator.model") + try: + device = infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex self._source_text_encoder = tokenizer.create_encoder( task="translation", lang=source_lang, mode="source", device=device @@ -204,7 +216,7 @@ def __init__( self._max_source_len = max_source_len - def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]: + def __call__(self, source_text: str) -> tuple[str, Seq2SeqGeneratorOutput]: """ :param source_text: The text in the source language. @@ -222,7 +234,7 @@ def __call__(self, source_text: str) -> Tuple[str, Seq2SeqGeneratorOutput]: def batch_translate( self, source_texts: Sequence[str] - ) -> Tuple[List[str], Seq2SeqGeneratorOutput]: + ) -> tuple[list[str], Seq2SeqGeneratorOutput]: """ :param source_texts: The texts in the source language. @@ -263,12 +275,17 @@ def __init__(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> No """ self._generator = generator - device = infer_device(generator.model, name="generator.model") + try: + device = infer_device(generator.model) + except ValueError as ex: + raise ValueError( + "The device of `generator.model` is not valid. See the nested exception for details." + ) from ex self._text_encoder = tokenizer.create_encoder(mode="prompt", device=device) self._text_decoder = tokenizer.create_decoder() - def __call__(self, prompt: str) -> Tuple[str, SequenceGeneratorOutput]: + def __call__(self, prompt: str) -> tuple[str, SequenceGeneratorOutput]: """ :param prompt: The text prompt. @@ -287,7 +304,7 @@ def __call__(self, prompt: str) -> Tuple[str, SequenceGeneratorOutput]: def batch_complete( self, prompts: Sequence[str] - ) -> Tuple[List[str], SequenceGeneratorOutput]: + ) -> tuple[list[str], SequenceGeneratorOutput]: """ :param prompts: The text prompts. @@ -308,16 +325,16 @@ def batch_complete( return self._do_complete(prompt_seqs, prompt_padding_mask) def _do_complete( - self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask] - ) -> Tuple[List[str], SequenceGeneratorOutput]: + self, prompt_seqs: Tensor, prompt_padding_mask: PaddingMask | None + ) -> tuple[list[str], SequenceGeneratorOutput]: generator_output = self._generator(prompt_seqs, prompt_padding_mask) - texts: List[str] = [] + texts: list[str] = [] for idx, hypotheses in enumerate(generator_output.hypotheses): if len(hypotheses) == 0: - raise RuntimeError( - f"The sequence generator returned no hypothesis at index {idx}. Please file a bug report." + raise ContractError( + f"The sequence generator returned no hypothesis at index {idx}." ) texts.append(self._text_decoder(hypotheses[0].seq)) diff --git a/src/fairseq2/generation/utils.py b/src/fairseq2/generation/utils.py deleted file mode 100644 index d8945a98b..000000000 --- a/src/fairseq2/generation/utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from typing import Optional, final - -from rich.console import Console -from torch import Tensor - -from fairseq2.console import get_console -from fairseq2.data.text import TextTokenDecoder - - -@final -class _StdOutPrintHook: - """Prints characters produced by a sequence generator to stdout.""" - - _text_decoder: TextTokenDecoder - _prev_text_len: int - _first_print: bool - _console: Console - - def __init__(self, text_decoder: TextTokenDecoder) -> None: - self._text_decoder = text_decoder - self._prev_text_len = 0 - self._first_print = True - - self._console = get_console() - - def __call__( - self, - prompt_indices: Tensor, - seqs: Tensor, - step_scores: Optional[Tensor], - prefill: bool, - ) -> None: - assert len(prompt_indices) == 1 - - # Do not print anything during prompt prefill. - if prefill: - return - - text = self._text_decoder(seqs[0]) - - text_len = len(text) - - # If this is our first print, determine the length of the prompt text. - if self._prev_text_len == 0: - prev_text = self._text_decoder(seqs[0][:-1]) - - prev_text_len = len(prev_text) - else: - prev_text_len = self._prev_text_len - - # Cache the length of the text so that we don't have to decode it twice - # in the next step. - self._prev_text_len = text_len - - # No need to print if we decoded a control symbol (e.g. EOS). - if text_len == prev_text_len: - return - - text = text[prev_text_len - text_len :] - - # Some models output several whitespace characters after the prompt. - if self._first_print: - text = text.lstrip() - if not text: - return - - self._first_print = False - - self._console.print(text, highlight=False, end="") diff --git a/src/fairseq2/logging.py b/src/fairseq2/logging.py index 3367096dd..31d94b3fb 100644 --- a/src/fairseq2/logging.py +++ b/src/fairseq2/logging.py @@ -8,7 +8,12 @@ import logging from logging import Logger, getLogger -from typing import Any, Final, Optional, final +from typing import Any, Final, final + + +def get_log_writer(name: str | None = None) -> LogWriter: + """Return a :class:`LogWriter` for the logger with the specified name.""" + return LogWriter(getLogger(name)) @final @@ -27,41 +32,41 @@ def __init__(self, logger: Logger) -> None: self._logger = logger def debug( - self, msg: Any, *args: Any, highlight: bool = False, **kwargs: Any + self, message: str, *args: Any, highlight: bool = False, **kwargs: Any ) -> None: """Log a message with level ``DEBUG``.""" - self._write(logging.DEBUG, msg, args, kwargs, highlight) + self._write(logging.DEBUG, message, args, kwargs, highlight) def info( - self, msg: Any, *args: Any, highlight: bool = False, **kwargs: Any + self, message: str, *args: Any, highlight: bool = False, **kwargs: Any ) -> None: """Log a message with level ``INFO``.""" - self._write(logging.INFO, msg, args, kwargs, highlight) + self._write(logging.INFO, message, args, kwargs, highlight) def warning( - self, msg: Any, *args: Any, highlight: bool = False, **kwargs: Any + self, message: str, *args: Any, highlight: bool = False, **kwargs: Any ) -> None: """Log a message with level ``WARNING``.""" - self._write(logging.WARNING, msg, args, kwargs, highlight) + self._write(logging.WARNING, message, args, kwargs, highlight) def error( - self, msg: Any, *args: Any, highlight: bool = False, **kwargs: Any + self, message: str, *args: Any, highlight: bool = False, **kwargs: Any ) -> None: """Log a message with level ``ERROR``.""" - self._write(logging.ERROR, msg, args, kwargs, highlight) + self._write(logging.ERROR, message, args, kwargs, highlight) def exception( - self, msg: Any, *args: Any, highlight: bool = False, **kwargs: Any + self, message: str, *args: Any, highlight: bool = False, **kwargs: Any ) -> None: """Log a message with level ``ERROR``.""" - self._write(logging.ERROR, msg, args, kwargs, highlight, exc_info=True) + self._write(logging.ERROR, message, args, kwargs, highlight, exc_info=True) def _write( self, level: int, - msg: Any, - args: Any, - kwargs: Any, + message: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], highlight: bool, exc_info: bool = False, ) -> None: @@ -69,11 +74,11 @@ def _write( if not self._logger.isEnabledFor(level): return - msg = str(msg).format(*args, **kwargs) + message = str(message).format(*args, **kwargs) extra = None if highlight else self._NO_HIGHLIGHT - self._logger.log(level, msg, extra=extra, exc_info=exc_info) + self._logger.log(level, message, exc_info=exc_info, extra=extra) def is_enabled_for(self, level: int) -> bool: """Return ``True`` if the writer is enabled for ``level``.""" @@ -88,6 +93,4 @@ def is_enabled_for_info(self) -> bool: return self._logger.isEnabledFor(logging.INFO) -def get_log_writer(name: Optional[str] = None) -> LogWriter: - """Return a :class:`LogWriter` for the logger with the specified name.""" - return LogWriter(getLogger(name)) +log = get_log_writer("fairseq2") diff --git a/src/fairseq2/metrics/__init__.py b/src/fairseq2/metrics/__init__.py index a9bf5d2cb..fde36b121 100644 --- a/src/fairseq2/metrics/__init__.py +++ b/src/fairseq2/metrics/__init__.py @@ -14,6 +14,7 @@ from fairseq2.metrics.recorder import LogMetricRecorder as LogMetricRecorder from fairseq2.metrics.recorder import MetricRecorder as MetricRecorder from fairseq2.metrics.recorder import TensorBoardRecorder as TensorBoardRecorder +from fairseq2.metrics.recorder import WandbRecorder as WandbRecorder from fairseq2.metrics.recorder import format_as_float as format_as_float from fairseq2.metrics.recorder import format_as_int as format_as_int from fairseq2.metrics.recorder import format_as_seconds as format_as_seconds diff --git a/src/fairseq2/metrics/aggregation.py b/src/fairseq2/metrics/aggregation.py index 0fa69634f..55e17a952 100644 --- a/src/fairseq2/metrics/aggregation.py +++ b/src/fairseq2/metrics/aggregation.py @@ -6,25 +6,20 @@ from __future__ import annotations -from typing import Iterable, Optional, Union - import torch from torch import Tensor from torcheval.metrics import Max as MaxBase from torcheval.metrics import Mean as MeanBase -from torcheval.metrics import Metric from torcheval.metrics import Min as MinBase from torcheval.metrics import Sum as SumBase -from typing_extensions import Self - -from fairseq2.typing import Device, override +from typing_extensions import Self, override class Min(MinBase): """See :class:`MinBase`.""" @override - def update(self, input_: Union[int, float, Tensor]) -> Self: + def update(self, input_: int | float | Tensor) -> Self: if isinstance(input_, (int, float)): input_ = torch.tensor(input_) @@ -37,7 +32,7 @@ class Max(MaxBase): """See :class:`MaxBase`.""" @override - def update(self, input_: Union[int, float, Tensor]) -> Self: + def update(self, input_: int | float | Tensor) -> Self: if isinstance(input_, (int, float)): input_ = torch.tensor(input_) @@ -52,9 +47,9 @@ class Mean(MeanBase): @override def update( self, - input_: Union[int, float, Tensor], + input_: int | float | Tensor, *, - weight: Union[int, float, Tensor] = 1.0, + weight: int | float | Tensor = 1.0, ) -> Self: if isinstance(input_, (int, float)): input_ = torch.tensor(input_) @@ -70,9 +65,9 @@ class Sum(SumBase): @override def update( self, - input_: Union[int, float, Tensor], + input_: int | float | Tensor, *, - weight: Union[int, float, Tensor] = 1.0, + weight: int | float | Tensor = 1.0, ) -> Self: if isinstance(input_, (int, float)): input_ = torch.tensor(input_) @@ -80,37 +75,3 @@ def update( super().update(input_, weight=weight) return self - - -class MaxSum(Metric[Tensor]): - """Calculate the sum of all elements in all the input tensors locally and - take the maximum value when merged with other metrics.""" - - sum_: Tensor - - def __init__(self, *, device: Optional[Device] = None) -> None: - super().__init__(device=device) - - sum_ = torch.zeros((), device=device, dtype=torch.int64) - - self._add_state("sum_", sum_) - - @override - @torch.inference_mode() - def update(self, input_: Union[int, Tensor]) -> Self: - self.sum_ += input_ - - return self - - @override - @torch.inference_mode() - def compute(self) -> Tensor: - return self.sum_ - - @override - @torch.inference_mode() - def merge_state(self, metrics: Iterable[MaxSum]) -> Self: - for metric in metrics: - self.sum_ = torch.max(self.sum_, metric.sum_.to(self.device)) - - return self diff --git a/src/fairseq2/metrics/bag.py b/src/fairseq2/metrics/bag.py index 9d1082c9b..1b03dafaa 100644 --- a/src/fairseq2/metrics/bag.py +++ b/src/fairseq2/metrics/bag.py @@ -7,12 +7,14 @@ from __future__ import annotations import logging +from collections.abc import Mapping, Sequence from copy import deepcopy -from typing import Any, Dict, Mapping, Optional, Sequence, final +from typing import Any, final from torcheval.metrics import Metric from torcheval.metrics.toolkit import sync_and_compute_collection +from fairseq2.error import ContractError, InternalError, InvalidOperationError from fairseq2.gang import Gang @@ -20,9 +22,9 @@ class MetricBag: """Holds a collection of training or validation metrics.""" _gang: Gang - _metrics: Dict[str, Metric[Any]] - _persistent_metrics: Dict[str, Metric[Any]] - _original_metrics: Optional[Dict[str, Metric[Any]]] + _metrics: dict[str, Metric[Any]] + _persistent_metrics: dict[str, Metric[Any]] + _original_metrics: dict[str, Metric[Any]] | None def __init__(self, gang: Gang) -> None: """ @@ -97,15 +99,20 @@ def begin_updates(self) -> None: or ``rollback_updates()``. """ if self._original_metrics is not None: - raise ValueError("`begin_updates()` has already been called.") + raise InvalidOperationError("`begin_updates()` has already been called.") - self._original_metrics = deepcopy(self._metrics) + try: + self._original_metrics = deepcopy(self._metrics) + except Exception as ex: + raise ContractError( + "The metrics in the bag cannot be copied. See the nested exception for details." + ) from ex @final def commit_updates(self) -> None: """Commit pending metric updates.""" if self._original_metrics is None: - raise ValueError("`begin_updates()` must be called first.") + raise InvalidOperationError("`begin_updates()` must be called first.") self._original_metrics = None @@ -113,7 +120,7 @@ def commit_updates(self) -> None: def rollback_updates(self) -> None: """Discard pending metric updates and rollback to the original state.""" if self._original_metrics is None: - raise ValueError("`begin_updates()` must be called first.") + raise InvalidOperationError("`begin_updates()` must be called first.") self._metrics, self._original_metrics = self._original_metrics, None @@ -135,11 +142,11 @@ def reset_non_persistent_metrics(self) -> None: metric.reset() @final - def sync_and_compute_metrics(self) -> Optional[Dict[str, Any]]: + def sync_and_compute_metrics(self) -> dict[str, object] | None: """Sync the metrics across all processes and compute their values.""" return sync_and_compute_metrics([self]) - def process_metric_values(self, values: Dict[str, Any]) -> None: + def process_metric_values(self, values: dict[str, object]) -> None: """Process metric ``values``.""" @property @@ -148,8 +155,8 @@ def metrics(self) -> Mapping[str, Metric[Any]]: return self._metrics @final - def state_dict(self) -> Dict[str, Any]: - state_dict = {} + def state_dict(self) -> dict[str, object]: + state_dict: dict[str, object] = {} for name, metric in self._persistent_metrics.items(): state_dict[name] = metric.state_dict() @@ -157,14 +164,41 @@ def state_dict(self) -> Dict[str, Any]: return state_dict @final - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: - if self._persistent_metrics.keys() != state_dict.keys(): - raise ValueError( - f"`state_dict` must contain metrics {list(self._persistent_metrics.keys())}, but contains {list(state_dict.keys())} instead." - ) + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: + state_keys = set(state_dict.keys()) + + metric_names = set(self._persistent_metrics.keys()) + + if metric_names != state_keys: + missing_metrics = metric_names - state_keys + if missing_metrics: + s = ", ".join(sorted(missing_metrics)) + + raise ValueError( + f"`state_dict` must contain the states of the following metric(s): {s}" + ) + + extra_keys = state_keys - metric_names + if extra_keys: + s = ", ".join(sorted(extra_keys)) + + raise ValueError( + f"`state_dict` must contain only the states of the metrics of this bag, but it contains the following unexpected key(s): {s}" + ) for name, metric in self._persistent_metrics.items(): - metric.load_state_dict(state_dict[name]) + metric_state_dict = state_dict[name] + if not isinstance(metric_state_dict, dict): + raise TypeError( + f"`state_dict['{name}']` must be of type `dict`, but is of type `{type(metric_state_dict)}` instead." + ) + + try: + metric.load_state_dict(metric_state_dict) + except (RuntimeError, ValueError) as ex: + raise ValueError( + f"`state_dict['{name}']` is not a valid `{type(metric)}` state. See the nested exception for details." + ) from ex metric.to(self._gang.device) @@ -181,7 +215,7 @@ def reset_non_persistent_metrics(bags: Sequence[MetricBag]) -> None: bag.reset_non_persistent_metrics() -def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> Optional[Dict[str, Any]]: +def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> dict[str, object] | None: """Sync the metrics across all processes and and compute their values.""" if not bags: return None @@ -210,7 +244,8 @@ def sync_and_compute_metrics(bags: Sequence[MetricBag]) -> Optional[Dict[str, An logging.disable(logging.NOTSET) if gang.rank == 0: - assert values is not None + if values is None: + raise InternalError("`values` is `None`.") def strip_underscore(s: str) -> str: if s.startswith("_"): diff --git a/src/fairseq2/metrics/recorder.py b/src/fairseq2/metrics/recorder.py index 6eef78e78..ddb31ab74 100644 --- a/src/fairseq2/metrics/recorder.py +++ b/src/fairseq2/metrics/recorder.py @@ -10,36 +10,32 @@ import math import re from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from datetime import datetime from functools import partial from logging import Logger from pathlib import Path from string import capwords -from typing import ( - Any, - Callable, - Dict, - Final, - Mapping, - Optional, - Sequence, - TextIO, - Union, - final, -) +from typing import Final, TextIO, final from torch import Tensor +from typing_extensions import override -from fairseq2.logging import LogWriter, get_log_writer -from fairseq2.typing import override +from fairseq2.error import AlreadyExistsError +from fairseq2.logging import LogWriter, log -def format_as_int(value: Any, *, postfix: Optional[str] = None) -> str: +def format_as_int(value: object, *, postfix: str | None = None) -> str: """Format metric ``value`` as integer.""" - try: - i = int(value) - except ValueError: + if isinstance(value, int): + i = value + elif isinstance(value, (str, Tensor, float)): + try: + i = int(value) + except ValueError: + return f"{value}" + else: return f"{value}" s = "<1" if i == 0 and isinstance(value, float) else f"{i:,}" @@ -54,13 +50,20 @@ def format_as_int(value: Any, *, postfix: Optional[str] = None) -> str: """Format metric ``value`` as duration in seconds.""" -def format_as_float(value: Any, *, postfix: Optional[str] = None) -> str: +def format_as_float(value: object, *, postfix: str | None = None) -> str: """Format metric ``value`` as float.""" - try: - s = f"{float(value):g}" - except ValueError: + if isinstance(value, float): + f = value + elif isinstance(value, (str, Tensor, int)): + try: + f = float(value) + except ValueError: + return f"{value}" + else: return f"{value}" + s = f"{f:g}" + if postfix: s += postfix @@ -70,15 +73,20 @@ def format_as_float(value: Any, *, postfix: Optional[str] = None) -> str: _UNITS: Final = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"] -def format_as_byte_size(value: Any) -> str: +def format_as_byte_size(value: object) -> str: """Format metric ``value`` in byte units.""" - unit_idx = 0 - - try: - size = float(value) - except ValueError: + if isinstance(value, float): + size = value + elif isinstance(value, (str, Tensor, int)): + try: + size = float(value) + except ValueError: + return f"{value}" + else: return f"{value}" + unit_idx = 0 + if not math.isfinite(size) or size <= 0.0: return "0 B" @@ -97,19 +105,28 @@ def format_as_byte_size(value: Any) -> str: class _MetricFormatter: display_name: str priority: int - fn: Callable[[Any], str] = str + fn: Callable[[object], str] = str log: bool = True -_metric_formatters: Dict[str, _MetricFormatter] = { +_metric_formatters: dict[str, _MetricFormatter] = { # fmt: off + "loss": _MetricFormatter("Loss", 90, format_as_float), + "contrastive_loss": _MetricFormatter("Contrastive Loss", 100, format_as_float), "ctc_loss": _MetricFormatter("CTC Loss", 100, format_as_float), + "diversity_loss": _MetricFormatter("Diversity Loss", 100, format_as_float), "nll_loss": _MetricFormatter("NLL Loss", 100, format_as_float), + "feature_penalty": _MetricFormatter("Feature Penalty", 110, format_as_float), + "accuracy": _MetricFormatter("Accuracy", 200, format_as_float), "bleu": _MetricFormatter("BLEU", 200, format_as_float), "chrf": _MetricFormatter("chrF++", 200, format_as_float), "uer": _MetricFormatter("Unit Error Rate (UER)", 200, format_as_float), "wer": _MetricFormatter("Word Error Rate (WER)", 200, format_as_float), + "code_perplexity": _MetricFormatter("Code Perplexity", 210, format_as_float), + "prob_perplexity": _MetricFormatter("Prob Perplexity", 210, format_as_float), + "temperature": _MetricFormatter("Temperature", 220, format_as_float), "gradient_norm": _MetricFormatter("Gradient Norm", 300, format_as_float), + "data_epoch": _MetricFormatter("Data Epoch", 490, format_as_int), "elapsed_time": _MetricFormatter("Elapsed Time", 500, format_as_seconds), "wall_time": _MetricFormatter("Wall Time", 510, format_as_seconds), "lr": _MetricFormatter("Learning Rate", 700, format_as_float), @@ -142,7 +159,7 @@ def register_metric_formatter( name: str, display_name: str, priority: int, - fn: Callable[[Any], str], + fn: Callable[[object], str], *, log: bool = True, overwrite: bool = False, @@ -163,14 +180,12 @@ def register_metric_formatter( If ``True``, overwrites any existing metric formatter with the same name. """ if name in _metric_formatters and not overwrite: - raise ValueError( - f"`name` must be a unique metric name, but '{name}' is already registered." - ) + raise AlreadyExistsError(f"'{name}' is already a registered metric name.") _metric_formatters[name] = _MetricFormatter(display_name, priority, fn, log) -def format_metric_value(name: str, value: Any) -> str: +def format_metric_value(name: str, value: object) -> str: """Format the specified metric along with its value as a string.""" formatter = _metric_formatters.get(name) if formatter is None: @@ -186,8 +201,8 @@ class MetricRecorder(ABC): def record_metrics( self, run: str, - values: Mapping[str, Any], - step_nr: Optional[int] = None, + values: Mapping[str, object], + step_nr: int | None = None, *, flush: bool = True, ) -> None: @@ -208,11 +223,15 @@ def close(self) -> None: """Close the recorder.""" +class MetricRecordError(Exception): + pass + + def record_metrics( recorders: Sequence[MetricRecorder], run: str, - values: Mapping[str, Any], - step_nr: Optional[int] = None, + values: Mapping[str, object], + step_nr: int | None = None, *, flush: bool = True, ) -> None: @@ -239,7 +258,7 @@ class LogMetricRecorder(MetricRecorder): _log: LogWriter - def __init__(self, log: Union[LogWriter, Logger]) -> None: + def __init__(self, log: LogWriter | Logger) -> None: """ :param log: The log writer or logger to use. @@ -253,8 +272,8 @@ def __init__(self, log: Union[LogWriter, Logger]) -> None: def record_metrics( self, run: str, - values: Mapping[str, Any], - step_nr: Optional[int] = None, + values: Mapping[str, object], + step_nr: int | None = None, *, flush: bool = True, ) -> None: @@ -302,7 +321,7 @@ class JsonFileMetricRecorder(MetricRecorder): _RUN_PART_REGEX: Final = re.compile("^[-_a-zA-Z0-9]+$") _output_dir: Path - _streams: Dict[str, TextIO] + _streams: dict[str, TextIO] def __init__(self, output_dir: Path) -> None: """ @@ -317,8 +336,8 @@ def __init__(self, output_dir: Path) -> None: def record_metrics( self, run: str, - values: Mapping[str, Any], - step_nr: Optional[int] = None, + values: Mapping[str, object], + step_nr: int | None = None, *, flush: bool = True, ) -> None: @@ -344,19 +363,20 @@ def record_metrics( # Sort by priority and display name. values_and_formatters.sort(key=lambda p: (p[1].priority, p[1].display_name)) - def sanitize(value: Any, formatter: _MetricFormatter) -> Any: + def sanitize(value: object, formatter: _MetricFormatter) -> object: if isinstance(value, Tensor): value = value.item() if formatter.fn is format_as_int: - try: - value = int(value) - except ValueError: - pass + if isinstance(value, (str, Tensor, float)): + try: + value = int(value) + except ValueError: + pass return value - output: Dict[str, Any] = {"Time": datetime.utcnow().isoformat()} + output: dict[str, object] = {"Time": datetime.utcnow().isoformat()} if step_nr is not None: output["Step"] = step_nr @@ -364,12 +384,17 @@ def sanitize(value: Any, formatter: _MetricFormatter) -> Any: for value, formatter in values_and_formatters: output[formatter.display_name] = sanitize(value, formatter) - json.dump(output, stream, indent=None) + try: + json.dump(output, stream, indent=None) - stream.write("\n") + stream.write("\n") - if flush: - stream.flush() + if flush: + stream.flush() + except OSError as ex: + raise MetricRecordError( + f"The metric values of the '{run}' cannot be saved to JSON file. See the nested exception for details." + ) from ex def _get_stream(self, run: str) -> TextIO: try: @@ -382,15 +407,15 @@ def _get_stream(self, run: str) -> TextIO: try: file.parent.mkdir(parents=True, exist_ok=True) except OSError as ex: - raise RuntimeError( - f"The metric directory ({file.parent}) cannot be created. See nested exception for details." + raise MetricRecordError( + f"The '{file.parent}' metric directory cannot be created. See the nested exception for details." ) from ex try: fp = file.open("a") except OSError as ex: - raise RuntimeError( - f"The metric file ({file}) cannot be created. See nested exception for details." + raise MetricRecordError( + f"The '{file}' metric file for the '{run} run cannot be created. See the nested exception for details." ) from ex self._streams[run] = fp @@ -417,29 +442,26 @@ def close(self) -> None: class TensorBoardRecorder(MetricRecorder): """Records metric values to TensorBoard.""" - _log_dir: Path - _writers: Dict[str, SummaryWriter] + _output_dir: Path + _writers: dict[str, SummaryWriter] - def __init__(self, log_dir: Path) -> None: + def __init__(self, output_dir: Path) -> None: """ - :param log_dir: + :param output_dir: The base directory under which to store the TensorBoard files. """ if not has_tensorboard: - log = get_log_writer(__name__) - log.warning("tensorboard not found. Please install it with `pip install tensorboard`.") # fmt: skip - self._log_dir = log_dir - + self._output_dir = output_dir self._writers = {} @override def record_metrics( self, run: str, - values: Mapping[str, Any], - step_nr: Optional[int] = None, + values: Mapping[str, object], + step_nr: int | None = None, *, flush: bool = True, ) -> None: @@ -447,26 +469,30 @@ def record_metrics( if writer is None: return - for name, value in values.items(): - formatter = _metric_formatters.get(name) - if formatter is None: - display_name = name - else: - display_name = formatter.display_name - - writer.add_scalar(display_name, value, step_nr) - - if flush: - writer.flush() + try: + for name, value in values.items(): + formatter = _metric_formatters.get(name) + if formatter is None: + display_name = name + else: + display_name = formatter.display_name + + writer.add_scalar(display_name, value, step_nr) + + if flush: + writer.flush() + except RuntimeError as ex: + raise MetricRecordError( + f"The metric values of the '{run}' cannot be saved to TensorBoard. See the nested exception for details." + ) from ex - def _get_writer(self, run: str) -> Optional[SummaryWriter]: + def _get_writer(self, run: str) -> SummaryWriter | None: if not has_tensorboard: return None - try: - writer = self._writers[run] - except KeyError: - writer = SummaryWriter(self._log_dir.joinpath(run)) + writer = self._writers.get(run) + if writer is None: + writer = SummaryWriter(self._output_dir.joinpath(run)) self._writers[run] = writer @@ -478,3 +504,65 @@ def close(self) -> None: writer.close() self._writers.clear() + + +try: + import wandb # type: ignore[import-not-found] +except ImportError: + has_wandb = False +else: + has_wandb = True + + +@final +class WandbRecorder(MetricRecorder): + """Records metric values to Weights & Biases.""" + + def __init__(self, project: str, name: str, output_dir: Path) -> None: + """ + :param project: The W&B project name. + :param name: The run name. + :param output_dir: The base directory under which to store the W&B files. + + In order to use W&B, run `wandb login` from the command line and enter + the API key when prompted. + """ + if not has_wandb: + log.warning("wandb not found. Please install it with `pip install wandb`.") # fmt: skip + + self._run = None + else: + self._run = wandb.init( + project=project, name=name, dir=output_dir.parent, resume="allow" + ) + + @override + def record_metrics( + self, + run: str, + values: Mapping[str, object], + step_nr: int | None = None, + *, + flush: bool = True, + ) -> None: + if self._run is None: + return + + for name, value in values.items(): + formatter = _metric_formatters.get(name) + if formatter is None: + display_name = name + else: + display_name = formatter.display_name + + try: + self._run.log({display_name: value}, step=step_nr) + except RuntimeError as ex: + raise MetricRecordError( + f"The metric values of the '{run}' cannot be saved to Weights & Biases. See the nested exception for details." + ) from ex + + @override + def close(self) -> None: + if self._run is not None: + self._run.finish() diff --git a/src/fairseq2/metrics/text/bleu.py b/src/fairseq2/metrics/text/bleu.py index 187507051..ef80b03d9 100644 --- a/src/fairseq2/metrics/text/bleu.py +++ b/src/fairseq2/metrics/text/bleu.py @@ -6,16 +6,17 @@ from __future__ import annotations -from typing import Iterable, Optional, Sequence, final +from collections.abc import Iterable, Sequence +from typing import final import torch from sacrebleu import corpus_bleu from sacrebleu.metrics.bleu import BLEU, MAX_NGRAM_ORDER from torch import Tensor from torcheval.metrics import Metric -from typing_extensions import Self +from typing_extensions import Self, override -from fairseq2.typing import Device, override +from fairseq2.typing import Device @final @@ -27,7 +28,7 @@ class BleuMetric(Metric[Tensor]): valid_ngrams: Tensor total_ngrams: Tensor - def __init__(self, *, device: Optional[Device] = None) -> None: + def __init__(self, *, device: Device | None = None) -> None: super().__init__(device=device) dtype = torch.int64 diff --git a/src/fairseq2/metrics/text/chrf.py b/src/fairseq2/metrics/text/chrf.py index 462baf538..5009f468b 100644 --- a/src/fairseq2/metrics/text/chrf.py +++ b/src/fairseq2/metrics/text/chrf.py @@ -6,15 +6,16 @@ from __future__ import annotations -from typing import Final, Iterable, Optional, Sequence, final +from collections.abc import Iterable, Sequence +from typing import Final, final import torch from sacrebleu.metrics.chrf import CHRF from torch import Tensor from torcheval.metrics import Metric -from typing_extensions import Self +from typing_extensions import Self, override -from fairseq2.typing import Device, override +from fairseq2.typing import Device @final @@ -26,7 +27,7 @@ class ChrfMetric(Metric[Tensor]): stats: Tensor - def __init__(self, *, device: Optional[Device] = None) -> None: + def __init__(self, *, device: Device | None = None) -> None: super().__init__(device=device) stats_len = 3 * (self.CHAR_ORDER + self.WORD_ORDER) diff --git a/src/fairseq2/metrics/text/wer.py b/src/fairseq2/metrics/text/wer.py index a258b33e9..dcf650375 100644 --- a/src/fairseq2/metrics/text/wer.py +++ b/src/fairseq2/metrics/text/wer.py @@ -6,20 +6,21 @@ from __future__ import annotations -from typing import Iterable, Optional, Sequence, Tuple, final +from collections.abc import Iterable, Sequence +from typing import final import editdistance import torch from torch import Tensor from torcheval.metrics import Metric -from typing_extensions import Self +from typing_extensions import Self, override from fairseq2.nn.padding import PaddingMask, get_seq_lens -from fairseq2.typing import Device, override +from fairseq2.typing import Device @final -class WerMetric(Metric[Tuple[Tensor, Tensor]]): +class WerMetric(Metric[tuple[Tensor, Tensor]]): """Computes the WER (Word Error Rate).""" unit_err: Tensor @@ -27,7 +28,7 @@ class WerMetric(Metric[Tuple[Tensor, Tensor]]): word_err: Tensor word_len: Tensor - def __init__(self, *, device: Optional[Device] = None) -> None: + def __init__(self, *, device: Device | None = None) -> None: super().__init__(device=device) dtype = torch.int64 @@ -50,10 +51,10 @@ def update( self, refs: Sequence[str], ref_seqs: Tensor, - ref_padding_mask: Optional[PaddingMask], + ref_padding_mask: PaddingMask | None, hyps: Sequence[str], hyp_seqs: Tensor, - hyp_padding_mask: Optional[PaddingMask], + hyp_padding_mask: PaddingMask | None, ) -> Self: """ :param refs: @@ -98,13 +99,13 @@ def update( @override @torch.inference_mode() - def compute(self) -> Tuple[Tensor, Tensor]: + def compute(self) -> tuple[Tensor, Tensor]: if self.unit_len and self.word_len: uer = self.unit_err * 100.0 / self.unit_len wer = self.word_err * 100.0 / self.word_len else: - uer = torch.zeros((), dtype=torch.float32) - wer = torch.zeros((), dtype=torch.float32) + uer = torch.tensor(-1.0, dtype=torch.float32) + wer = torch.tensor(-1.0, dtype=torch.float32) return uer, wer diff --git a/src/fairseq2/models/__init__.py b/src/fairseq2/models/__init__.py index b89d3c74d..583520139 100644 --- a/src/fairseq2/models/__init__.py +++ b/src/fairseq2/models/__init__.py @@ -6,22 +6,17 @@ from __future__ import annotations -from fairseq2.models.chatbot import ChatbotFactory as ChatbotFactory -from fairseq2.models.chatbot import DelegatingChatbotFactory as DelegatingChatbotFactory -from fairseq2.models.chatbot import create_chatbot as create_chatbot from fairseq2.models.config_loader import ModelConfigLoader as ModelConfigLoader from fairseq2.models.config_loader import ( StandardModelConfigLoader as StandardModelConfigLoader, ) from fairseq2.models.config_loader import get_model_family as get_model_family from fairseq2.models.config_loader import is_model_card as is_model_card -from fairseq2.models.factory import ( - DelegatingGenericModelFactory as DelegatingGenericModelFactory, -) -from fairseq2.models.factory import ModelFactory as ModelFactory from fairseq2.models.factory import create_model as create_model +from fairseq2.models.factory import model_factories as model_factories from fairseq2.models.loader import CheckpointConverter as CheckpointConverter from fairseq2.models.loader import DelegatingModelLoader as DelegatingModelLoader +from fairseq2.models.loader import ModelFactory as ModelFactory from fairseq2.models.loader import ModelLoader as ModelLoader from fairseq2.models.loader import StandardModelLoader as StandardModelLoader from fairseq2.models.loader import load_model as load_model diff --git a/src/fairseq2/models/chatbot.py b/src/fairseq2/models/chatbot.py deleted file mode 100644 index 6bdc3dcc9..000000000 --- a/src/fairseq2/models/chatbot.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from typing import Dict, Protocol - -from fairseq2.data.text import TextTokenizer -from fairseq2.generation import Chatbot, SequenceGenerator - - -class ChatbotFactory(Protocol): - """Constructs instances of :class:`Chatbot`.""" - - def __call__( - self, generator: SequenceGenerator, tokenizer: TextTokenizer - ) -> Chatbot: - """ - :param generator: - The sequence generator. - :param tokenizer: - The text tokenizer. - """ - - -class DelegatingChatbotFactory(ChatbotFactory): - """Constructs instances of :class:`Chatbot` using registered factories.""" - - _factories: Dict[str, ChatbotFactory] - - def __init__(self) -> None: - self._factories = {} - - def __call__( - self, generator: SequenceGenerator, tokenizer: TextTokenizer - ) -> Chatbot: - family = generator.model.family - if family is None: - raise ValueError("`generator.model.family` must not be `None`.") - - try: - factory = self._factories[family] - except KeyError: - raise ValueError( - f"`generator.model.family` must be a supported model family, but '{family}' has no registered chatbot." - ) from None - - return factory(generator, tokenizer) - - def register(self, family: str, factory: ChatbotFactory) -> None: - """Register a chatbot factory to use with this factory. - - :param family: - The model family supported by ``factory``. - :param factory: - The chatbot factory. - """ - if family in self._factories: - raise ValueError( - f"`family` must be a unique model family name, but '{family}' has already a registered chatbot." - ) - - self._factories[family] = factory - - -create_chatbot = DelegatingChatbotFactory() diff --git a/src/fairseq2/models/config_loader.py b/src/fairseq2/models/config_loader.py index 862e5c6b5..24d0ee05f 100644 --- a/src/fairseq2/models/config_loader.py +++ b/src/fairseq2/models/config_loader.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Protocol, Type, TypeVar, Union, final +from typing import Protocol, TypeVar, cast, final from fairseq2.assets import ( AssetCard, @@ -18,8 +18,13 @@ ) from fairseq2.config_registry import ConfigRegistry from fairseq2.typing import DataClass -from fairseq2.utils.dataclass import FieldError, update_dataclass -from fairseq2.utils.value_converter import ValueConverter, default_value_converter +from fairseq2.utils.dataclass import merge_dataclass +from fairseq2.utils.structured import ( + StructureError, + ValueConverter, + default_value_converter, + merge_unstructured, +) ModelConfigT = TypeVar("ModelConfigT", bound=DataClass) @@ -29,7 +34,9 @@ class ModelConfigLoader(Protocol[ModelConfigT_co]): """Loads model configurations of type ``ModelConfigT``.""" - def __call__(self, model_name_or_card: Union[str, AssetCard]) -> ModelConfigT_co: + def __call__( + self, model_name_or_card: str | AssetCard, unstructured_config: object = None + ) -> ModelConfigT_co: """ :param model_name_or_card: The name or the asset card of the model whole configuration to load. @@ -42,18 +49,17 @@ class StandardModelConfigLoader(ModelConfigLoader[ModelConfigT]): _asset_store: AssetStore _family: str - _config_kls: Type[ModelConfigT] - _arch_configs: Optional[ConfigRegistry[ModelConfigT]] - _value_converter: ValueConverter + _config_kls: type[ModelConfigT] + _arch_configs: ConfigRegistry[ModelConfigT] | None def __init__( self, - *, family: str, - config_kls: Type[ModelConfigT], - arch_configs: Optional[ConfigRegistry[ModelConfigT]], - asset_store: Optional[AssetStore] = None, - value_converter: Optional[ValueConverter] = None, + config_kls: type[ModelConfigT], + arch_configs: ConfigRegistry[ModelConfigT] | None, + *, + asset_store: AssetStore | None = None, + value_converter: ValueConverter | None = None, ) -> None: """ :param family: @@ -75,7 +81,9 @@ def __init__( self._arch_configs = arch_configs self._value_converter = value_converter or default_value_converter - def __call__(self, model_name_or_card: Union[str, AssetCard]) -> ModelConfigT: + def __call__( + self, model_name_or_card: str | AssetCard, unstructured_config: object = None + ) -> ModelConfigT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card else: @@ -84,9 +92,11 @@ def __call__(self, model_name_or_card: Union[str, AssetCard]) -> ModelConfigT: model_family = get_model_family(card) if model_family != self._family: raise AssetCardError( - f"The value of the field 'model_family' of the asset card '{card.name}' must be '{self._family}', but is '{model_family}' instead." + card.name, f"The value of the field 'model_family' of the asset card '{card.name}' must be '{self._family}', but is '{model_family}' instead." # fmt: skip ) + config_kls = self._config_kls + try: arch = card.field("model_arch").as_(str) except AssetCardFieldNotFoundError: @@ -95,7 +105,7 @@ def __call__(self, model_name_or_card: Union[str, AssetCard]) -> ModelConfigT: # Load the configuration. if arch is None: try: - config = self._config_kls() + base_config = config_kls() except TypeError as ex: raise AssetError( f"The '{self._family}' model family has no default configuration." @@ -107,27 +117,57 @@ def __call__(self, model_name_or_card: Union[str, AssetCard]) -> ModelConfigT: ) try: - config = self._arch_configs.get(arch) + base_config = self._arch_configs.get(arch) except ValueError: raise AssetError( f"The '{self._family}' model family has no architecture named '{arch}'." ) from None - # Check whether to override anything in the default configuration. - if config_overrides := card.field("model_config").get_as_(Dict[str, Any]): + # Override the default architecture configuration if needed. + model_config_fields = [] + + card_: AssetCard | None = card + + while card_ is not None: + if "model_config" in card_.metadata: + model_config_field = card_.field("model_config").as_unstructured() + + model_config_fields.append(model_config_field) + + card_ = card_.base + + if model_config_fields: try: - unknown_fields = update_dataclass( - config, config_overrides, value_converter=self._value_converter + unstructured_base_config = self._value_converter.unstructure( + base_config ) - except FieldError as ex: - raise AssetCardError( - f"The value of the field 'model_config' of the asset card '{card.name}' must be a valid model configuration, but the value of the configuration field '{ex.field_name}' is invalid. See nested exception for details." + except StructureError as ex: + raise AssetError( + f"The model configuration class of the '{self._family}' cannot be used. Please file a bug report to the model author." ) from ex - if unknown_fields: - raise AssetCardError( - f"The value of the field 'model_config' of the asset card '{card.name}' must be a valid model configuration, but the following configuration fields are unknown: {', '.join(unknown_fields)}" + try: + for model_config_field in reversed(model_config_fields): + unstructured_base_config = merge_unstructured( + unstructured_base_config, model_config_field + ) + + base_config = self._value_converter.structure( + unstructured_base_config, config_kls ) + except StructureError as ex: + raise AssetError( + f"The value of the field 'model_config' of the asset card '{card.name}' cannot be parsed as a valid model configuration. Please file a bug report to the asset author." + ) from ex + + if unstructured_config is None: + config = base_config + else: + config = self._value_converter.structure( + unstructured_config, config_kls, set_empty=True + ) + + config = merge_dataclass(base_config, config) return config @@ -140,14 +180,14 @@ def is_model_card(card: AssetCard) -> bool: def get_model_family(card: AssetCard) -> str: """Return the model family name contained in ``card``.""" try: - return card.field("model_family").as_(str) # type: ignore[no-any-return] + return cast(str, card.field("model_family").as_(str)) except AssetCardFieldNotFoundError: pass try: # Compatibility with older fairseq2 versions. - return card.field("model_type").as_(str) # type: ignore[no-any-return] + return cast(str, card.field("model_type").as_(str)) except AssetCardFieldNotFoundError: raise AssetCardFieldNotFoundError( - f"The asset card '{card.name}' must have a field named 'model_family." + card.name, f"The asset card '{card.name}' must have a field named 'model_family." # fmt: skip ) from None diff --git a/src/fairseq2/models/conformer/block.py b/src/fairseq2/models/conformer/block.py index 0ef951235..7d8916c9f 100644 --- a/src/fairseq2/models/conformer/block.py +++ b/src/fairseq2/models/conformer/block.py @@ -6,10 +6,11 @@ from __future__ import annotations -from typing import Optional, Tuple, final +from typing import final from torch import Tensor from torch.nn import Dropout +from typing_extensions import override from fairseq2.models.conformer.convolution import ConformerConvolution from fairseq2.nn import LayerNorm @@ -22,7 +23,7 @@ TransformerEncoderLayer, create_standard_layer_norm, ) -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -32,16 +33,16 @@ class ConformerBlock(TransformerEncoderLayer): ffn1_layer_norm: LayerNorm ffn1: FeedForwardNetwork - ffn1_dropout: Optional[Dropout] + ffn1_dropout: Dropout | None self_attn_layer_norm: LayerNorm self_attn: MultiheadAttention - self_attn_dropout: Optional[Dropout] + self_attn_dropout: Dropout | None conv_layer_norm: LayerNorm conv: ConformerConvolution - conv_dropout: Optional[Dropout] + conv_dropout: Dropout | None ffn2_layer_norm: LayerNorm ffn2: FeedForwardNetwork - ffn2_dropout: Optional[Dropout] + ffn2_dropout: Dropout | None layer_norm: LayerNorm def __init__( @@ -52,9 +53,9 @@ def __init__( ffn2: FeedForwardNetwork, *, dropout_p: float = 0.0, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param ffn1: @@ -122,9 +123,9 @@ def __init__( def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + ) -> tuple[Tensor, PaddingMask | None]: seqs = self._forward_ffn1(seqs) seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask) @@ -152,8 +153,8 @@ def _forward_ffn1(self, seqs: Tensor) -> Tensor: def _forward_self_attn( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask], + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None, ) -> Tensor: residual = seqs @@ -173,9 +174,7 @@ def _forward_self_attn( return seqs + residual - def _forward_conv( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tensor: + def _forward_conv(self, seqs: Tensor, padding_mask: PaddingMask | None) -> Tensor: residual = seqs seqs = self.conv_layer_norm(seqs) diff --git a/src/fairseq2/models/conformer/convolution.py b/src/fairseq2/models/conformer/convolution.py index faffbdc5c..0cda65f87 100644 --- a/src/fairseq2/models/conformer/convolution.py +++ b/src/fairseq2/models/conformer/convolution.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Literal, Optional, final +from typing import Literal, final from torch import Tensor from torch.nn import GLU, BatchNorm1d, Conv1d, Module, SiLU @@ -27,8 +27,8 @@ class ConformerConvolution(Module): pointwise_conv1_activation: GLU depthwise_conv: Conv1d causal_depthwise_conv: bool - batch_norm: Optional[BatchNorm1d] - layer_norm: Optional[LayerNorm] + batch_norm: BatchNorm1d | None + layer_norm: LayerNorm | None depthwise_activation: Module pointwise_conv2: Conv1d @@ -39,9 +39,9 @@ def __init__( *, causal_depthwise_conv: bool = False, norm_type: Literal["batch_norm", "layer_norm"] = "batch_norm", - depthwise_activation: Optional[Module] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + depthwise_activation: Module | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -115,7 +115,7 @@ def __init__( model_dim, model_dim, kernel_size=1, bias=False, device=device, dtype=dtype ) - def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor: + def forward(self, seqs: Tensor, padding_mask: PaddingMask | None) -> Tensor: """ :param seqs: The sequences to process. *Shape:* :math:`(N,S,M)`, where :math:`N` diff --git a/src/fairseq2/models/decoder.py b/src/fairseq2/models/decoder.py index 7fb7379c9..073ef3913 100644 --- a/src/fairseq2/models/decoder.py +++ b/src/fairseq2/models/decoder.py @@ -7,15 +7,14 @@ from __future__ import annotations from abc import abstractmethod -from typing import Optional, Tuple from torch import Tensor +from typing_extensions import override from fairseq2.data import VocabularyInfo from fairseq2.models.sequence import SequenceBatch, SequenceModel, SequenceModelOutput from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import override class DecoderModel(SequenceModel): @@ -50,10 +49,10 @@ def forward(self, batch: SequenceBatch) -> SequenceModelOutput: def decode( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: """Decode the specified sequences. :param seqs: @@ -77,7 +76,7 @@ def decode( @abstractmethod def project( - self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask] + self, decoder_output: Tensor, decoder_padding_mask: PaddingMask | None ) -> SequenceModelOutput: """Produce logits for next-step prediction. diff --git a/src/fairseq2/models/encoder_decoder.py b/src/fairseq2/models/encoder_decoder.py index bea0248d6..6c43ae771 100644 --- a/src/fairseq2/models/encoder_decoder.py +++ b/src/fairseq2/models/encoder_decoder.py @@ -7,16 +7,15 @@ from __future__ import annotations from abc import abstractmethod -from typing import Optional, Tuple from torch import Tensor +from typing_extensions import override from fairseq2.data import VocabularyInfo from fairseq2.models.seq2seq import Seq2SeqBatch, Seq2SeqModel from fairseq2.models.sequence import SequenceModelOutput from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import override class EncoderDecoderModel(Seq2SeqModel): @@ -56,8 +55,8 @@ def forward(self, batch: Seq2SeqBatch) -> SequenceModelOutput: @abstractmethod def encode( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: """Encode the specified source sequences. :param seqs: @@ -83,12 +82,12 @@ def encode( def decode( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + encoder_padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: """Decode the specified target sequences. :param seqs: @@ -123,7 +122,7 @@ def decode( @abstractmethod def project( - self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask] + self, decoder_output: Tensor, decoder_padding_mask: PaddingMask | None ) -> SequenceModelOutput: """Produce logits for next-step prediction. diff --git a/src/fairseq2/models/factory.py b/src/fairseq2/models/factory.py index a79f1dfc5..cfb17e053 100644 --- a/src/fairseq2/models/factory.py +++ b/src/fairseq2/models/factory.py @@ -6,247 +6,42 @@ from __future__ import annotations -from typing import ( - Any, - Dict, - Generic, - Mapping, - Optional, - Protocol, - Tuple, - Type, - TypeVar, - final, -) - +import torch +from mypy_extensions import DefaultNamedArg from torch.nn import Module -from fairseq2.config_registry import ConfigRegistry -from fairseq2.typing import DataClass, DataType, Device -from fairseq2.utils.dataclass import FieldError, update_dataclass -from fairseq2.utils.value_converter import ValueConverter, default_value_converter - -ModelT = TypeVar("ModelT", bound=Module) - -ModelT_co = TypeVar("ModelT_co", bound=Module, covariant=True) - -ModelConfigT = TypeVar("ModelConfigT", bound=DataClass) - -ModelConfigT_contra = TypeVar( - "ModelConfigT_contra", bound=DataClass, contravariant=True -) - - -class ModelFactory(Protocol[ModelConfigT_contra, ModelT_co]): - """Constructs models of type ``ModelT``.""" - - def __call__( - self, - config: ModelConfigT_contra, - *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, - ) -> ModelT_co: - """ - :param config: - The model configuration. - :param device: - The device on which to initialize the model. - :param dtype: - The data type of the model parameters and buffers. - """ - - -class GenericModelFactory(Protocol): - """Constructs models.""" - - def __call__( - self, - family: str, - arch: Optional[str], - config: Any, - *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, - ) -> Tuple[Module, DataClass]: - """ - :param family: - The family of the model. - :param arch: - The architecture of the model. If ``None``, uses the default model - configuration. - :param config: - The model configuration object or a dictionary where keys will - override corresponding fields in the model configuration of ``arch``. - """ - - -@final -class StandardGenericModelFactory(GenericModelFactory, Generic[ModelT, ModelConfigT]): - """Constructs models.""" - - _family: str - _factory: ModelFactory[ModelConfigT, ModelT] - _config_kls: Type[ModelConfigT] - _arch_configs: Optional[ConfigRegistry[ModelConfigT]] - _value_converter: ValueConverter - - def __init__( - self, - *, - family: str, - factory: ModelFactory[ModelConfigT, ModelT], - config_kls: Type[ModelConfigT], - arch_configs: Optional[ConfigRegistry[ModelConfigT]], - value_converter: Optional[ValueConverter] = None, - ) -> None: - """ - :param family: - The model family. - :param factory: - The factory to construct models. - :param config_kls: - The type of the model configuration. - :param arch_configs: - The registry containing all supported model architectures. - :param value_converter: - The :class:`ValueConverter` instance to use. If ``None``, the - default instance will be used. - """ - self._family = family - self._factory = factory - self._config_kls = config_kls - self._arch_configs = arch_configs - self._value_converter = value_converter or default_value_converter - - def __call__( - self, - family: str, - arch: Optional[str], - config: Any, - *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, - ) -> Tuple[Module, DataClass]: - if family != self._family: - raise ValueError( - f"`family` must be '{self._family}', but is '{family}' instead." - ) - - if isinstance(config, self._config_kls): - model = self._factory(config, device=device, dtype=dtype) - - return model, config - - if arch is None: - try: - config_ = self._config_kls() - except TypeError as ex: - raise RuntimeError( - f"The '{family}' model family has not default configuration." - ) from ex - else: - if self._arch_configs is None: - raise ValueError( - f"`arch` must be a registered architecture, but the '{family}' model family has no architecture named '{arch}'." - ) - - try: - config_ = self._arch_configs.get(arch) - except ValueError: - raise ValueError( - f"`arch` must be a registered architecture, but the '{family}' model family has no architecture named '{arch}'." - ) from None - - if config is not None: - if not isinstance(config, Mapping): - raise ValueError( - f"`config` must be of type `{self._config_kls}` or `{Mapping}`, but is of type `{type(config)}` instead." - ) - - try: - unknown_fields = update_dataclass( - config_, config, value_converter=self._value_converter - ) - except FieldError as ex: - raise ValueError( - f"`config` must be a valid model configuration, but the value of the configuration field '{ex.field_name}' is invalid. See nested exception for details." - ) from ex - - if unknown_fields: - raise ValueError( - f"`config` must be a valid model configuration, but the following configuration fields are unknown: {', '.join(unknown_fields)}" - ) - - model = self._factory(config_, device=device, dtype=dtype) - - return model, config_ - - -@final -class DelegatingGenericModelFactory(GenericModelFactory): - """Constructs models using registered factories.""" - - _factories: Dict[str, GenericModelFactory] - - def __init__(self) -> None: - self._factories = {} - - def __call__( - self, - family: str, - arch: Optional[str], - config: Any, - *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, - ) -> Tuple[Module, DataClass]: - try: - factory = self._factories[family] - except KeyError: - raise ValueError( - f"`family` must be a supported model family, but '{family}' has no registered factory." - ) from None - - return factory(family, arch, config, device=device, dtype=dtype) - - def register( - self, - *, - family: str, - factory: ModelFactory[ModelConfigT, ModelT], - config_kls: Type[ModelConfigT], - arch_configs: Optional[ConfigRegistry[ModelConfigT]], - value_converter: Optional[ValueConverter] = None, - ) -> None: - """Register a model factory. - - :param family: - The model family supported by ``factory``. - :param factory: - The factory to construct models. - :param config_kls: - The type of the model configuration. - :param arch_configs: - The registry containing all supported model architectures. - :param value_converter: - The :class:`ValueConverter` instance to use. If ``None``, the - default instance will be used. - """ - if family in self._factories: - raise ValueError( - f"`family` must be a unique model family name, but '{family}' has already a registered factory." - ) - - generic_factory = StandardGenericModelFactory( - family=family, - factory=factory, - config_kls=config_kls, - arch_configs=arch_configs, - value_converter=value_converter, - ) - - self._factories[family] = generic_factory - - -create_model = DelegatingGenericModelFactory() +from fairseq2.factory_registry import ConfigBoundFactoryRegistry +from fairseq2.typing import CPU, DataClass, DataType, Device + +model_factories = ConfigBoundFactoryRegistry[ + [DefaultNamedArg(Device, "device"), DefaultNamedArg(DataType, "dtype")], Module +]() + + +def create_model( + family: str, + arch: str | None = None, + unstructured_config: object = None, + *, + device: Device | None = None, + dtype: DataType | None = None, +) -> tuple[Module, DataClass]: + """Create a model of type registered with ``family``. + + :param family: + The family of the model. + :param arch: + The architecture of the model. + :param unstructured_config: + The (partial) configuration of the model. Any ``EMPTY`` field will be + filled with the corresponding value from the configuration of ``arch``. + + :returns: + - The model. + - The effective configuration of the model. + """ + factory = model_factories.get(family, unstructured_config, arch, set_empty=True) + + model = factory(device=device or CPU, dtype=dtype or torch.float32) + + return model, factory.config diff --git a/src/fairseq2/models/feature_extractor.py b/src/fairseq2/models/feature_extractor.py index 04e61a42f..877823e37 100644 --- a/src/fairseq2/models/feature_extractor.py +++ b/src/fairseq2/models/feature_extractor.py @@ -7,7 +7,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Tuple from torch import Tensor from torch.nn import Module @@ -31,8 +30,8 @@ def __init__(self, feature_dim: int) -> None: @abstractmethod def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: """ :param seqs: The sequences from which to extract features. *Shape:* @@ -44,9 +43,9 @@ def forward( is the batch size and :math:`S` is the sequence length. :returns: - - The extracted features. *Shape:* :math:`(N,S_{out},F)`, where + - The extracted features. *Shape:* :math:`(N,S_{out},E)`, where :math:`N` is the batch size, :math:`S_{out}` is the output - sequence length, and :math:`F` is the dimensionality of the + sequence length, and :math:`E` is the dimensionality of the features. - The padding mask of the extracted features. *Shape:* :math:`(N,S_{out})`, where :math:`N` is the batch size and diff --git a/src/fairseq2/models/fsdp.py b/src/fairseq2/models/fsdp.py index 5b130d9e3..d3d19290d 100644 --- a/src/fairseq2/models/fsdp.py +++ b/src/fairseq2/models/fsdp.py @@ -7,7 +7,7 @@ from __future__ import annotations from functools import partial -from typing import List, Literal, Optional, Set, Tuple, Type +from typing import Literal from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn import Module @@ -23,7 +23,7 @@ def get_fsdp_wrap_policy( model: Module, wrap_granularity: Literal["layer", "stack", "model"] = "layer" -) -> Tuple[Optional[FSDPWrapPolicy], Optional[List[Module]]]: +) -> tuple[FSDPWrapPolicy | None, list[Module] | None]: """Return the FSDP wrap policy for ``model`` along with ignored modules. :param model: @@ -38,7 +38,7 @@ def get_fsdp_wrap_policy( if wrap_granularity == "model": return None, None - kls: Set[Type[Module]] + kls: set[type[Module]] if wrap_granularity == "stack": kls = {TransformerEncoder, TransformerDecoder} diff --git a/src/fairseq2/models/jepa/__init__.py b/src/fairseq2/models/jepa/__init__.py new file mode 100644 index 000000000..79586c6bf --- /dev/null +++ b/src/fairseq2/models/jepa/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.models.jepa.factory import JEPA_FAMILY as JEPA_FAMILY +from fairseq2.models.jepa.factory import JepaBuilder as JepaBuilder +from fairseq2.models.jepa.factory import JepaConfig as JepaConfig +from fairseq2.models.jepa.factory import JepaEncoderBuilder as JepaEncoderBuilder +from fairseq2.models.jepa.factory import JepaEncoderConfig as JepaEncoderConfig +from fairseq2.models.jepa.factory import create_jepa_model as create_jepa_model +from fairseq2.models.jepa.factory import jepa_arch as jepa_arch +from fairseq2.models.jepa.factory import jepa_archs as jepa_archs +from fairseq2.models.jepa.loader import load_jepa_config as load_jepa_config +from fairseq2.models.jepa.loader import load_jepa_model as load_jepa_model + +# isort: split + +import fairseq2.models.jepa.archs # Register architectures diff --git a/src/fairseq2/models/jepa/archs.py b/src/fairseq2/models/jepa/archs.py new file mode 100644 index 000000000..d325f0a5f --- /dev/null +++ b/src/fairseq2/models/jepa/archs.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.models.jepa.factory import JepaConfig, jepa_arch + + +@jepa_arch("tiny") +def tiny() -> JepaConfig: + config = base() + + config.encoder_config.model_dim = 192 + config.encoder_config.num_encoder_attn_heads = 3 + + return config + + +@jepa_arch("small") +def small() -> JepaConfig: + config = base() + + config.encoder_config.model_dim = 384 + config.encoder_config.num_encoder_attn_heads = 6 + + return config + + +@jepa_arch("base") +def base() -> JepaConfig: + return JepaConfig() + + +@jepa_arch("large") +def large() -> JepaConfig: + config = base() + + config.encoder_config.model_dim = 1024 + config.encoder_config.num_encoder_layers = 24 + config.encoder_config.num_encoder_attn_heads = 16 + + return config + + +@jepa_arch("huge") +def huge() -> JepaConfig: + config = base() + + config.encoder_config.model_dim = 1280 + config.encoder_config.num_encoder_layers = 32 + config.encoder_config.num_encoder_attn_heads = 16 + + return config + + +@jepa_arch("giant") +def giant() -> JepaConfig: + config = base() + + config.encoder_config.model_dim = 1408 + config.encoder_config.num_encoder_layers = 40 + config.encoder_config.num_encoder_attn_heads = 16 + config.encoder_config.ffn_inner_dim_ratio = 48 / 11 + + return config + + +@jepa_arch("gigantic") +def gigantic() -> JepaConfig: + config = base() + + config.encoder_config.model_dim = 1664 + config.encoder_config.num_encoder_layers = 48 + config.encoder_config.num_encoder_attn_heads = 16 + config.encoder_config.ffn_inner_dim_ratio = 64 / 13 + + return config diff --git a/src/fairseq2/models/jepa/classifier/__init__.py b/src/fairseq2/models/jepa/classifier/__init__.py new file mode 100644 index 000000000..7f321afb0 --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import fairseq2.models.jepa.classifier.archs # Register architectures +from fairseq2.models.jepa.classifier.factory import ( + JEPA_CLASSIFIER_FAMILY as JEPA_CLASSIFIER_FAMILY, +) +from fairseq2.models.jepa.classifier.factory import ( + JepaClassifierBuilder as JepaClassifierBuilder, +) +from fairseq2.models.jepa.classifier.factory import ( + JepaClassifierConfig as JepaClassifierConfig, +) +from fairseq2.models.jepa.classifier.factory import ( + create_jepa_classifier_model as create_jepa_classifier_model, +) +from fairseq2.models.jepa.classifier.factory import ( + jepa_classifier_archs as jepa_classifier_archs, +) +from fairseq2.models.jepa.classifier.model import ( + JepaClassifierModel as JepaClassifierModel, +) + +# isort: split diff --git a/src/fairseq2/models/jepa/classifier/archs.py b/src/fairseq2/models/jepa/classifier/archs.py new file mode 100644 index 000000000..b131e2342 --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/archs.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.models.jepa.archs import base as jepa_base +from fairseq2.models.jepa.archs import huge as jepa_huge +from fairseq2.models.jepa.archs import large as jepa_large +from fairseq2.models.jepa.classifier.factory import ( + JepaClassifierConfig, + jepa_classifier_arch, +) + + +@jepa_classifier_arch("base") +def base() -> JepaClassifierConfig: + pretrain_config = jepa_base() + return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config) + + +@jepa_classifier_arch("large") +def large() -> JepaClassifierConfig: + pretrain_config = jepa_large() + return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config) + + +@jepa_classifier_arch("huge") +def huge() -> JepaClassifierConfig: + pretrain_config = jepa_huge() + return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config) diff --git a/src/fairseq2/models/jepa/classifier/factory.py b/src/fairseq2/models/jepa/classifier/factory.py new file mode 100644 index 000000000..bb7a6ea5d --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/factory.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import final + +from fairseq2.config_registry import ConfigRegistry +from fairseq2.models.factory import model_factories +from fairseq2.models.jepa import JepaEncoderBuilder, JepaEncoderConfig +from fairseq2.models.jepa.classifier.model import ( + AttentivePooler, + CrossAttentionDecoderLayer, + JepaClassifierModel, +) +from fairseq2.nn.projection import IdentityProjection, Linear, Projection +from fairseq2.nn.transformer import ( + MultiheadAttention, + StandardMultiheadAttention, + create_default_sdpa, +) +from fairseq2.typing import DataType, Device + +JEPA_CLASSIFIER_FAMILY = "jepa_classifier" + + +@dataclass(kw_only=True) +class JepaClassifierConfig: + encoder_config: JepaEncoderConfig = field( + default_factory=lambda: JepaEncoderConfig() + ) + """The configuration of the vision encoder.""" + + pool_depth: int = 1 + """The pool depth (minimum 1 decoder layer)""" + + decoder_projection: bool = True + """If True, the decoder will have a linear layer on top""" + + num_queries: int = 1 + """Number of query tokens in the attention pool layer""" + + num_classes: int = 1000 + """Size of classification logits""" + + +jepa_classifier_archs = ConfigRegistry[JepaClassifierConfig]() + +jepa_classifier_arch = jepa_classifier_archs.decorator + + +@final +class JepaClassifierBuilder: + """Build a JEPA model fine-tuned for classification""" + + _config: JepaClassifierConfig + _encoder_builder: JepaEncoderBuilder + _device: Device | None + _dtype: DataType | None + + def __init__( + self, + config: JepaClassifierConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + self._config = config + + self._encoder_builder = JepaEncoderBuilder( + config.encoder_config, device=device, dtype=dtype + ) + + self._device, self._dtype = device, dtype + + def build_model(self) -> JepaClassifierModel: + encoder_frontend = self._encoder_builder.build_frontend() + encoder = self._encoder_builder.build_encoder() + pooler = self.build_pooler() + head = self.build_head() + + return JepaClassifierModel(encoder_frontend, encoder, pooler, head) + + def build_pooler(self) -> AttentivePooler: + config = self._config + + if config.pool_depth > 1: + encoder = self._encoder_builder.build_encoder(config.pool_depth) + else: + encoder = None + + decoder = self.build_decoder_layer() + + return AttentivePooler( + decoder=decoder, + encoder=encoder, + num_queries=config.num_queries, + init_std=config.encoder_config.init_std, + device=self._device, + dtype=self._dtype, + ) + + def build_head(self) -> Projection: + config = self._config + return Linear( + config.encoder_config.model_dim, + config.num_classes, + device=self._device, + dtype=self._dtype, + bias=True, + ) + + def build_decoder_layer(self) -> CrossAttentionDecoderLayer: + config = self._config + + cross_attn = self.build_cross_attention() + + ffn = self._encoder_builder.build_ffn(config.pool_depth) + + return CrossAttentionDecoderLayer( + cross_attn, + ffn, + layer_norm_factory=self._encoder_builder.build_layer_norm, + device=self._device, + dtype=self._dtype, + ) + + def build_cross_attention(self) -> MultiheadAttention: + config = self._config.encoder_config + + model_dim = config.model_dim + + sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) + + output_proj = self.build_cross_attn_output_projection() + + return StandardMultiheadAttention( + model_dim, + config.num_encoder_attn_heads, + sdpa=sdpa, + bias=config.qkv_bias, + output_proj=output_proj, + device=self._device, + dtype=self._dtype, + ) + + def build_cross_attn_output_projection(self) -> Projection: + config = self._config + + model_dim = config.encoder_config.model_dim + + if config.decoder_projection: + return Linear( + model_dim, + model_dim, + bias=True, + device=self._device, + dtype=self._dtype, + ) + else: + return IdentityProjection(model_dim, model_dim) + + +def create_jepa_classifier_model( + config: JepaClassifierConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, +) -> JepaClassifierModel: + return JepaClassifierBuilder( + config, + device=device, + dtype=dtype, + ).build_model() + + +model_factories.register( + JEPA_CLASSIFIER_FAMILY, + create_jepa_classifier_model, + JepaClassifierConfig, + jepa_classifier_archs, +) diff --git a/src/fairseq2/models/jepa/classifier/model.py b/src/fairseq2/models/jepa/classifier/model.py new file mode 100644 index 000000000..0ba5541d3 --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/model.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import final + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Module, Parameter + +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.transformer import TransformerFrontend +from fairseq2.nn.normalization import LayerNorm +from fairseq2.nn.projection import Projection +from fairseq2.nn.transformer import ( + FeedForwardNetwork, + LayerNormFactory, + MultiheadAttention, + TransformerEncoder, + create_standard_layer_norm, +) +from fairseq2.typing import DataType, Device + + +@final +class JepaClassifierModel(Module): + """ + Represents a pretrained Jepa model, with an attentive probing layer for + classfication tasks. See + * :cite:t:`https://doi.org/10.48550/arXiv.2301.08243` + * :cite:t:`https://doi.org/10.48550/arXiv.2404.08471` + """ + + model_dim: int + encoder_frontend: TransformerFrontend + encoder: TransformerEncoder + pooler: AttentivePooler + head: Projection + + def __init__( + self, + encoder_frontend: TransformerFrontend, + encoder: TransformerEncoder, + pooler: AttentivePooler, + head: Projection, + ) -> None: + super().__init__() + + self.model_dim = encoder.model_dim + + self.encoder_frontend = encoder_frontend + self.encoder = encoder + + self.pooler = pooler + + self.head = head + + def forward(self, batch: SequenceBatch) -> Tensor: + seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) + + seqs, _ = self.encoder(seqs, padding_mask) + + seqs = self.pooler(seqs) + + # (N, P, M) + seqs = seqs.squeeze(1) # TODO: NEEDED? + + return self.head(seqs) # type: ignore[no-any-return] + + def extra_repr(self) -> str: + """:meta private:""" + return f"model_dim={self.model_dim}" + + +@final +class AttentivePooler(Module): + """ + An attentive pooler that gets output of a Jepa encoder and decode it into + a logit of a given task. + + TODO: + - Move this into fairseq2.nn to benefit other similiar tasks. Internally, + this module is just a thin transformer encoder without self attention layer. + Optionally, it can consist of some extra transformer encoders depending on the + (finetuning) task + """ + + model_dim: int + decoder: CrossAttentionDecoderLayer + encoder: TransformerEncoder | None + query_tokens: Parameter + init_std: float + + def __init__( + self, + decoder: CrossAttentionDecoderLayer, + encoder: TransformerEncoder | None, + *, + num_queries: int = 1, + init_std: float = 0.02, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + super().__init__() + + self.model_dim = decoder.model_dim + + self.decoder = decoder + + if encoder: + self.encoder = encoder + else: + self.register_module("encoder", None) + + self.query_tokens = Parameter( + torch.empty((1, num_queries, self.model_dim), device=device, dtype=dtype) + ) + + self.init_std = init_std + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + nn.init.trunc_normal_(self.query_tokens, std=self.init_std) + + def forward(self, seqs: Tensor) -> Tensor: + if self.encoder is not None: + seqs, _ = self.encoder(seqs, padding_mask=None) + + batch_size = seqs.size(0) + + # (1, P, M) -> (N, P, M) + pool_seqs = self.query_tokens.repeat(batch_size, 1, 1) + + return self.decoder(pool_seqs, seqs) # type: ignore[no-any-return] + + def extra_repr(self) -> str: + """:meta private:""" + return f"model_dim={self.model_dim}, num_queries={self.query_tokens.size(1)}" + + +@final +class CrossAttentionDecoderLayer(Module): + """Represents a simple transformer decoder with only cross attention and layernorm""" + + model_dim: int + cross_attn_layer_norm: LayerNorm + cross_attn: MultiheadAttention + ffn_layer_norm: LayerNorm + ffn: FeedForwardNetwork + + def __init__( + self, + cross_attn: MultiheadAttention, + ffn: FeedForwardNetwork, + *, + layer_norm_factory: LayerNormFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param cross_attn: + The encoder-decoder attention layer. + :param ffn: + The feed-forward network. + :param layer_norm_factory: + The factory to construct the Layer Normalization modules. + """ + super().__init__() + + model_dim = cross_attn.model_dim + + if layer_norm_factory is None: + layer_norm_factory = create_standard_layer_norm + + self.cross_attn_layer_norm = layer_norm_factory( + model_dim, device=device, dtype=dtype + ) + + self.model_dim = model_dim + + self.cross_attn = cross_attn + + self.ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + self.ffn = ffn + + def forward(self, seqs: Tensor, encoder_output: Tensor) -> Tensor: + seqs = self._forward_cross_attn(seqs, encoder_output) + + seqs = self._forward_ffn(seqs) + + return seqs + + def _forward_cross_attn(self, seqs: Tensor, encoder_output: Tensor) -> Tensor: + residual = seqs + + # Note that the cross-attention norm is applied on encoder output and not seqs + encoder_output = self.cross_attn_layer_norm(encoder_output) + + seqs = self.cross_attn( + seqs, + padding_mask=None, + keys=encoder_output, + key_padding_mask=None, + values=encoder_output, + ) + + seqs = seqs + residual + + return seqs + + def _forward_ffn(self, seqs: Tensor) -> Tensor: + residual = seqs + + seqs = self.ffn_layer_norm(seqs) + + seqs = self.ffn(seqs) + + seqs = seqs + residual + + return seqs + + def extra_repr(self) -> str: + """:meta private:""" + return f"model_dim={self.model_dim}" diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py new file mode 100644 index 000000000..b0eb74969 --- /dev/null +++ b/src/fairseq2/models/jepa/factory.py @@ -0,0 +1,420 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Final, cast + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import GELU, Conv2d, Conv3d + +from fairseq2.config_registry import ConfigRegistry +from fairseq2.models.jepa.model import JepaModel +from fairseq2.models.transformer import TransformerFrontend +from fairseq2.models.vit import ( + Conv2dPatchFeatureExtractor, + Conv3dPatchFeatureExtractor, + PatchFeatureExtractor, + StandardViTFrontend, +) +from fairseq2.nn import ( + InterpolatedPositionEncoder, + LayerNorm, + Linear, + Sinusoidal2dPositionEncoder, + Sinusoidal3dPositionEncoder, + StandardLayerNorm, +) +from fairseq2.nn.transformer import ( + FeedForwardNetwork, + MultiheadAttention, + StandardFeedForwardNetwork, + StandardMultiheadAttention, + StandardTransformerEncoder, + StandardTransformerEncoderLayer, + TransformerEncoder, + TransformerEncoderLayer, + TransformerNormOrder, + create_default_sdpa, +) +from fairseq2.nn.transformer.residual import DropPathResidualConnect +from fairseq2.typing import DataType, Device + +JEPA_FAMILY: Final = "jepa" + + +@dataclass(kw_only=True) +class JepaConfig: + """ + Holds the configuration of a JEPA model. + + The default values correspond to the 'base' JEPA architecture. + """ + + encoder_config: JepaEncoderConfig = field( + default_factory=lambda: JepaEncoderConfig() + ) + """The configuration of the Vision Transformer encoder.""" + + +@dataclass(kw_only=True) +class JepaEncoderConfig: + model_dim: int = 768 + """The dimensionality of the model.""" + + num_input_channels: int = 3 + """The number of input channels per frame.""" + + input_dims: tuple[int, ...] = (224, 224) + """ + The supported native dimensionality of inputs. Expected to be 2-dimensional + (height, width) for images and 3-dimensional (depth, height, width) for + videos. + """ + + patch_dims: tuple[int, ...] = (16, 16) + """The dimensionality of patches to be extracted from inputs.""" + + num_encoder_layers: int = 12 + """The number of encoder layers.""" + + num_encoder_attn_heads: int = 12 + """The number of attention heads in encoder layers.""" + + qkv_bias: bool = True + """ + If ``True``, query, key, and value projections in multi-head attention + layers will have an additive bias. + """ + + attn_dropout_p: float = 0.0 + """The dropout probability on attention weights.""" + + ffn_inner_dim_ratio: float = 4.0 + """ + The ratio of the dimensionality of the inner projection layers in + feed-forward networks to :attr:`model_dim`. + """ + + init_std: float = 0.02 + """ + The standard deviation to initialize weights and biases of projection and + normalization layers. + """ + + dropout_p: float = 0.0 + """The dropout probability on outputs of Transformer layers.""" + + droppath_p: float = 0.0 + """ + The probability of dropping sequences from outputs of multi-head attention + and feed-forward network layers before adding residuals. + """ + + uniform_power: bool = False + """ + If ``True``, each patch dimension will have equal representation in the + produced positional encodings. + """ + + +jepa_archs = ConfigRegistry[JepaConfig]() + +jepa_arch = jepa_archs.decorator + + +# TODO(balioglu): work in progress. Supports only vision encoder. +class JepaBuilder: + """Builds modules of a JEPA model.""" + + _config: JepaConfig + _encoder_builder: JepaEncoderBuilder + _device: Device | None + _dtype: DataType | None + + def __init__( + self, + config: JepaConfig, + encoder_builder: JepaEncoderBuilder | None = None, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + self._config = config + + if encoder_builder is None: + encoder_builder = JepaEncoderBuilder( + config.encoder_config, device=device, dtype=dtype + ) + + self._encoder_builder = encoder_builder + + self._device, self._dtype = device, dtype + + def build_model(self) -> JepaModel: + encoder_frontend = self._encoder_builder.build_frontend() + + encoder = self._encoder_builder.build_encoder() + + return JepaModel(encoder_frontend, encoder) + + +class JepaEncoderBuilder: + """Builds modules of a JEPA Vision Transformer encoder.""" + + _config: JepaEncoderConfig + _device: Device | None + _dtype: DataType | None + + def __init__( + self, + config: JepaEncoderConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + self._config = config + + self._device, self._dtype = device, dtype + + def build_frontend(self) -> TransformerFrontend: + config = self._config + + if len(config.input_dims) != len(config.patch_dims): + raise ValueError( + f"The lengths of `input_dims` and `patch_dims` must match, but they are {len(config.input_dims)} and {len(config.patch_dims)} instead." + ) + + feature_extractor = self.build_feature_extractor() + + pos_encoder = self.build_position_encoder() + + return StandardViTFrontend(feature_extractor, pos_encoder) + + def build_feature_extractor(self) -> PatchFeatureExtractor: + config = self._config + + init_std = config.init_std + + num_patch_dims = len(config.patch_dims) + + if num_patch_dims == 3: + patch_3d_dims = cast(tuple[int, int, int], config.patch_dims) + + def init_conv3d(conv: Conv3d) -> None: + init_truncated_normal(conv.weight, conv.bias, std=init_std) + + return Conv3dPatchFeatureExtractor( + config.num_input_channels, + config.model_dim, + patch_3d_dims, + init_fn=init_conv3d, + device=self._device, + dtype=self._dtype, + ) + elif num_patch_dims == 2: + patch_2d_dims = cast(tuple[int, int], config.patch_dims) + + def init_conv2d(conv: Conv2d) -> None: + init_truncated_normal(conv.weight, conv.bias, std=init_std) + + return Conv2dPatchFeatureExtractor( + config.num_input_channels, + config.model_dim, + patch_2d_dims, + init_fn=init_conv2d, + device=self._device, + dtype=self._dtype, + ) + else: + raise ValueError( + f"The length of `patch_dims` must be 2 or 3, but is {num_patch_dims} instead." + ) + + def build_position_encoder(self) -> InterpolatedPositionEncoder: + config = self._config + + num_input_dims = len(config.input_dims) + + if num_input_dims == 3: + input_3d_dims = cast(tuple[int, int, int], config.input_dims) + patch_3d_dims = cast(tuple[int, int, int], config.patch_dims) + + d_input_dim, h_input_dim, w_input_dim = input_3d_dims + d_patch_dim, h_patch_dim, w_patch_dim = patch_3d_dims + + grid_3d_dims = ( + (d_input_dim // d_patch_dim), + (h_input_dim // h_patch_dim), + (w_input_dim // w_patch_dim), + ) + + return Sinusoidal3dPositionEncoder( + config.model_dim, + grid_3d_dims, + uniform_power=config.uniform_power, + device=self._device, + ) + elif num_input_dims == 2: + input_2d_dims = cast(tuple[int, int], config.input_dims) + patch_2d_dims = cast(tuple[int, int], config.patch_dims) + + h_input_dim, w_input_dim = input_2d_dims + h_patch_dim, w_patch_dim = patch_2d_dims + + grid_2d_dims = (h_input_dim // h_patch_dim), (w_input_dim // w_patch_dim) + + return Sinusoidal2dPositionEncoder( + config.model_dim, grid_2d_dims, device=self._device + ) + else: + raise ValueError( + f"The length of `input_dims` must be 2 or 3, but is {num_input_dims} instead." + ) + + def build_encoder(self, num_layers: int | None = None) -> TransformerEncoder: + config = self._config + + if num_layers is None: + num_layers = config.num_encoder_layers + + layers = [self.build_encoder_layer(i) for i in range(num_layers)] + + return StandardTransformerEncoder( + layers, + norm_order=TransformerNormOrder.PRE, + layer_norm_factory=self.build_layer_norm, + device=self._device, + dtype=self._dtype, + ) + + def build_encoder_layer(self, layer_idx: int) -> TransformerEncoderLayer: + config = self._config + + self_attn = self.build_attention(layer_idx) + + ffn = self.build_ffn(layer_idx) + + drop_path = DropPathResidualConnect(drop_p=config.droppath_p) + + return StandardTransformerEncoderLayer( + self_attn, + ffn, + dropout_p=config.dropout_p, + norm_order=TransformerNormOrder.PRE, + layer_norm_factory=self.build_layer_norm, + self_attn_residual=drop_path, + ffn_residual=drop_path, + device=self._device, + dtype=self._dtype, + ) + + def build_attention(self, layer_idx: int) -> MultiheadAttention: + config = self._config + + sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) + + output_proj = self.build_mha_output_projection(layer_idx) + + return StandardMultiheadAttention( + config.model_dim, + config.num_encoder_attn_heads, + sdpa=sdpa, + bias=config.qkv_bias, + output_proj=output_proj, + device=self._device, + dtype=self._dtype, + ) + + def build_mha_output_projection(self, layer_idx: int) -> Linear: + config = self._config + + init_std = config.init_std + + def init_projection(proj: Linear) -> None: + init_truncated_normal(proj.weight, proj.bias, std=init_std) + + with torch.no_grad(): + proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) + + return Linear( + config.model_dim, + config.model_dim, + bias=True, + init_fn=init_projection, + device=self._device, + dtype=self._dtype, + ) + + def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: + config = self._config + + init_std = config.init_std + + def init_projection(proj: Linear) -> None: + init_truncated_normal(proj.weight, proj.bias, std=init_std) + + with torch.no_grad(): + proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) + + inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) + + return StandardFeedForwardNetwork( + config.model_dim, + inner_dim, + bias=True, + inner_activation=GELU(), + proj_init_fn=init_projection, + norm_order=TransformerNormOrder.PRE, + device=self._device, + dtype=self._dtype, + ) + + def build_layer_norm( + self, + model_dim: int, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> LayerNorm: + config = self._config + + init_std = config.init_std + + def init_layer_norm(m: LayerNorm) -> None: + if m.weight is not None: + init_truncated_normal(m.weight, m.bias, std=init_std) + + return StandardLayerNorm( + model_dim, + bias=True, + eps=1e-6, + init_fn=init_layer_norm, + device=device, + dtype=dtype, + ) + + +def init_truncated_normal( + weight: Tensor, bias: Tensor | None, *, std: float = 1.0 +) -> None: + nn.init.trunc_normal_(weight, std=std) + + if bias is not None: + nn.init.zeros_(bias) + + +def create_jepa_model( + config: JepaConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, +) -> JepaModel: + return JepaBuilder(config, device=device, dtype=dtype).build_model() diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py new file mode 100644 index 000000000..a38c613e7 --- /dev/null +++ b/src/fairseq2/models/jepa/loader.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Any + +import torch + +from fairseq2.models.config_loader import StandardModelConfigLoader +from fairseq2.models.jepa.factory import ( + JEPA_FAMILY, + JepaConfig, + create_jepa_model, + jepa_archs, +) +from fairseq2.models.loader import StandardModelLoader +from fairseq2.models.utils.checkpoint import convert_model_state_dict + +load_jepa_config = StandardModelConfigLoader(JEPA_FAMILY, JepaConfig, jepa_archs) + + +def convert_jepa_checkpoint( + checkpoint: dict[str, Any], config: JepaConfig +) -> dict[str, Any]: + # We have a shared checkpoint, used for other use cases (frozen evaluation,..) + if "target_encoder" in checkpoint: + return convert_jepa_encoder_checkpoint( + checkpoint["target_encoder"], config=config + ) + + if "encoder" in checkpoint: + return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config=config) + + raise ValueError(f"encoder not found (available keys: {checkpoint.keys()})") + + +def convert_jepa_encoder_checkpoint( + checkpoint: dict[str, Any], config: JepaConfig +) -> dict[str, Any]: + del checkpoint["module.backbone.pos_embed"] + + new_checkpoint = {} + + for name, param in checkpoint.items(): + if name.endswith("qkv.weight"): + q_proj, k_proj, v_proj = torch.chunk(param, 3, dim=0) + + new_checkpoint[name[:-10] + "q_proj.weight"] = q_proj + new_checkpoint[name[:-10] + "k_proj.weight"] = k_proj + new_checkpoint[name[:-10] + "v_proj.weight"] = v_proj + + continue + + if name.endswith("qkv.bias"): + q_bias, k_bias, v_bias = torch.chunk(param, 3, dim=0) + + new_checkpoint[name[:-8] + "q_proj.bias"] = q_bias + new_checkpoint[name[:-8] + "k_proj.bias"] = k_bias + new_checkpoint[name[:-8] + "v_proj.bias"] = v_bias + + continue + + new_checkpoint[name] = param + + key_map = { + # fmt: off + r"^module\.backbone\.blocks\.([0-9]+)\.attn\.q_proj\.": r"encoder.layers.\1.self_attn.q_proj.", + r"^module\.backbone\.blocks\.([0-9]+)\.attn\.k_proj\.": r"encoder.layers.\1.self_attn.k_proj.", + r"^module\.backbone\.blocks\.([0-9]+)\.attn\.v_proj\.": r"encoder.layers.\1.self_attn.v_proj.", + r"^module\.backbone\.blocks\.([0-9]+)\.attn\.proj\.": r"encoder.layers.\1.self_attn.output_proj.", + r"^module\.backbone\.blocks\.([0-9]+)\.norm1\.": r"encoder.layers.\1.self_attn_layer_norm.", + r"^module\.backbone\.blocks\.([0-9]+)\.mlp\.fc1\.": r"encoder.layers.\1.ffn.inner_proj.", + r"^module\.backbone\.blocks\.([0-9]+)\.mlp\.fc2\.": r"encoder.layers.\1.ffn.output_proj.", + r"^module\.backbone\.blocks\.([0-9]+)\.norm2\.": r"encoder.layers.\1.ffn_layer_norm.", + r"^module\.backbone\.norm\.": r"encoder.layer_norm.", + r"^module\.backbone\.patch_embed\.proj\.": r"encoder_frontend.feature_extractor.conv.", + # fmt: on + } + + checkpoint = convert_model_state_dict(new_checkpoint, key_map) + + return {"model": checkpoint} + + +load_jepa_model = StandardModelLoader( + config_loader=load_jepa_config, + factory=create_jepa_model, + checkpoint_converter=convert_jepa_checkpoint, +) diff --git a/src/fairseq2/models/jepa/model.py b/src/fairseq2/models/jepa/model.py new file mode 100644 index 000000000..b1413328a --- /dev/null +++ b/src/fairseq2/models/jepa/model.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import final + +from torch.nn import Module + +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.transformer import TransformerFrontend +from fairseq2.nn.transformer import TransformerEncoder + +# TODO(balioglu): This implementation is not complete. As of this commit, only +# the encoder and encoder-frontend are available for parity check purposes. + + +@final +class JepaModel(Module): + """ + Represents a JEPA model as described in: + * :cite:t:`https://doi.org/10.48550/arXiv.2301.08243` + * :cite:t:`https://doi.org/10.48550/arXiv.2404.08471` + """ + + model_dim: int + encoder_frontend: TransformerFrontend + encoder: TransformerEncoder + + def __init__( + self, + encoder_frontend: TransformerFrontend, + encoder: TransformerEncoder, + ) -> None: + super().__init__() + + self.model_dim = encoder.model_dim + + self.encoder_frontend = encoder_frontend + self.encoder = encoder + + def forward(self, batch: SequenceBatch) -> SequenceBatch: + seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) + + seqs, padding_mask = self.encoder(seqs, padding_mask) + + return SequenceBatch(seqs, padding_mask) diff --git a/src/fairseq2/models/llama/__init__.py b/src/fairseq2/models/llama/__init__.py index eddf9ba1e..0ffff5245 100644 --- a/src/fairseq2/models/llama/__init__.py +++ b/src/fairseq2/models/llama/__init__.py @@ -6,9 +6,6 @@ from __future__ import annotations -from fairseq2.models.llama.chatbot import LLaMA3Chatbot as LLaMA3Chatbot -from fairseq2.models.llama.chatbot import LLaMAChatbot as LLaMAChatbot -from fairseq2.models.llama.chatbot import create_llama_chatbot as create_llama_chatbot from fairseq2.models.llama.factory import LLAMA_FAMILY as LLAMA_FAMILY from fairseq2.models.llama.factory import LLaMABuilder as LLaMABuilder from fairseq2.models.llama.factory import LLaMAConfig as LLaMAConfig @@ -18,8 +15,6 @@ from fairseq2.models.llama.factory import llama_archs as llama_archs from fairseq2.models.llama.loader import load_llama_config as load_llama_config from fairseq2.models.llama.loader import load_llama_model as load_llama_model -from fairseq2.models.llama.loader import load_llama_tokenizer as load_llama_tokenizer -from fairseq2.models.llama.tokenizer import LLaMA3Tokenizer as LLaMA3Tokenizer # isort: split diff --git a/src/fairseq2/models/llama/archs.py b/src/fairseq2/models/llama/archs.py index 103ae58b8..6856aec3f 100644 --- a/src/fairseq2/models/llama/archs.py +++ b/src/fairseq2/models/llama/archs.py @@ -7,7 +7,7 @@ from __future__ import annotations from fairseq2.data import VocabularyInfo -from fairseq2.models.llama.factory import LLaMAConfig, llama_arch +from fairseq2.models.llama.factory import LLaMAConfig, RopeScaling, llama_arch @llama_arch("7b") @@ -121,7 +121,7 @@ def _llama3_1_8b() -> LLaMAConfig: config = _llama3_8b() config.max_seq_len = 131_072 - config.use_scaled_rope = True + config.rope_scaling = RopeScaling() return config @@ -131,6 +131,36 @@ def _llama3_1_70b() -> LLaMAConfig: config = _llama3_70b() config.max_seq_len = 131_072 - config.use_scaled_rope = True + config.rope_scaling = RopeScaling() + + return config + + +@llama_arch("llama3_2_3b") +def _llama3_2_3b() -> LLaMAConfig: + config = _llama3_1_8b() + + config.model_dim = 3072 + config.ffn_inner_dim = int(3072 * 4 * 1.0) + config.ffn_inner_dim_to_multiple = 256 + config.num_attn_heads = 24 + config.num_key_value_heads = 8 + config.num_layers = 28 + config.rope_scaling = RopeScaling(factor=32.0) + + return config + + +@llama_arch("llama3_2_1b") +def _llama3_2_1b() -> LLaMAConfig: + config = _llama3_1_8b() + + config.model_dim = 2048 + config.ffn_inner_dim = int(2048 * 4 * 1.5) + config.ffn_inner_dim_to_multiple = 256 + config.num_attn_heads = 32 + config.num_key_value_heads = 8 + config.num_layers = 16 + config.rope_scaling = RopeScaling(factor=32.0) return config diff --git a/src/fairseq2/models/llama/factory.py b/src/fairseq2/models/llama/factory.py index 0188dffe0..cf8b383a3 100644 --- a/src/fairseq2/models/llama/factory.py +++ b/src/fairseq2/models/llama/factory.py @@ -6,16 +6,17 @@ from __future__ import annotations +import functools import math from dataclasses import dataclass, field -from typing import Final, Optional +from typing import Final, final import torch from torch import Tensor from fairseq2.config_registry import ConfigRegistry from fairseq2.data import VocabularyInfo -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.transformer import ( TransformerDecoderModel, TransformerEmbeddingFrontend, @@ -41,7 +42,7 @@ LLAMA_FAMILY: Final = "llama" -@dataclass +@dataclass(kw_only=True) class LLaMAConfig: """Holds the configuration of a LLaMA model. @@ -85,13 +86,35 @@ class LLaMAConfig: rope_theta: float = 10_000.0 """The coefficient of the long-term decay of the Rotary position encoder.""" - use_scaled_rope: bool = False - """If ``True``, scales Rotary encoding frequencies to LLaMA 3.1 context length.""" + rope_scaling: RopeScaling | None = None + """If specified, provides scaling parameters for RoPE frequencies, + aiming to increase the context length.""" dropout_p: float = 0.1 """The dropout probability on outputs of Transformer layers.""" +@final +@dataclass +class RopeScaling: + """Holds the configuration for RoPE (Rotary Position Embedding) + scaling in Llama 3 models. + """ + + factor: float = 8.0 + """Ratio between the intended max context length and the model’s + original max context length.""" + + low_freq_factor: float = 1.0 + """Factor used to define low frequencies.""" + + high_freq_factor: float = 4.0 + """Factor used to define high frequencies.""" + + original_context_length: int = 8192 + """Original context length. Defaults to LLaMA 3's context length.""" + + llama_archs = ConfigRegistry[LLaMAConfig]() llama_arch = llama_archs.decorator @@ -107,16 +130,16 @@ class LLaMABuilder: """ _config: LLaMAConfig - _device: Optional[Device] - _dtype: Optional[DataType] - _pos_encoder: Optional[RotaryEncoder] + _device: Device | None + _dtype: DataType | None + _pos_encoder: RotaryEncoder | None def __init__( self, config: LLaMAConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -147,7 +170,7 @@ def build_model(self) -> TransformerDecoderModel: dtype=self._dtype, ) - return TransformerDecoderModel( + model = TransformerDecoderModel( decoder_frontend, decoder, final_proj, @@ -155,6 +178,10 @@ def build_model(self) -> TransformerDecoderModel: self._config.vocab_info, ) + model.set_family(LLAMA_FAMILY) + + return model + def build_decoder_frontend(self) -> TransformerFrontend: """Build a Transformer decoder front-end.""" embed = StandardEmbedding( @@ -213,8 +240,11 @@ def build_attention( sdpa = create_default_sdpa(attn_dropout_p=self._config.dropout_p) if self._pos_encoder is None: - if self._config.use_scaled_rope: - freqs_init_fn = self._init_scaled_freqs + if self._config.rope_scaling is not None: + freqs_init_fn = functools.partial( + self._init_scaled_freqs, + rope_scaling=self._config.rope_scaling, + ) else: freqs_init_fn = None @@ -254,15 +284,21 @@ def build_layer_norm( self, model_dim: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> LayerNorm: """Build a Layer Normalization module.""" return RMSNorm(model_dim, bias=False, device=device, dtype=dtype) @staticmethod - def _init_scaled_freqs(pos_encoder: RotaryEncoder) -> Tensor: + def _init_scaled_freqs( + pos_encoder: RotaryEncoder, rope_scaling: RopeScaling + ) -> Tensor: device = pos_encoder.freqs.device + scale_factor = rope_scaling.factor + l_freq_factor = rope_scaling.low_freq_factor + h_freq_factor = rope_scaling.high_freq_factor + old_context_len = rope_scaling.original_context_length # (E / 2) indices = torch.arange( @@ -274,13 +310,6 @@ def _init_scaled_freqs(pos_encoder: RotaryEncoder) -> Tensor: if device.type == "meta": return freqs # type: ignore[no-any-return] - old_context_len = 8192 # The context length of LLaMA 3. - - scale_factor = 8.0 - - l_freq_factor = 1 - h_freq_factor = 5 - l_freq_wavelen = old_context_len / l_freq_factor h_freq_wavelen = old_context_len / h_freq_factor @@ -288,6 +317,7 @@ def _init_scaled_freqs(pos_encoder: RotaryEncoder) -> Tensor: for freq in freqs.tolist(): wavelen = 2 * math.pi / freq + if wavelen < h_freq_wavelen: new_freqs.append(freq) elif wavelen > l_freq_wavelen: @@ -302,29 +332,14 @@ def _init_scaled_freqs(pos_encoder: RotaryEncoder) -> Tensor: def create_llama_model( config: LLaMAConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> TransformerDecoderModel: - """Create a LLaMA model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - model = LLaMABuilder(config, device=device, dtype=dtype).build_model() + """Create a LLaMA model.""" + return LLaMABuilder(config, device=device, dtype=dtype).build_model() - return model.set_family(LLAMA_FAMILY) - -create_model.register( - family=LLAMA_FAMILY, - factory=create_llama_model, - config_kls=LLaMAConfig, - arch_configs=llama_archs, -) +model_factories.register(LLAMA_FAMILY, create_llama_model, LLaMAConfig, llama_archs) def get_llama_lora_config() -> LoRAConfig: diff --git a/src/fairseq2/models/llama/integ.py b/src/fairseq2/models/llama/integ.py index 7a3a2ad4c..16ee8bbc0 100644 --- a/src/fairseq2/models/llama/integ.py +++ b/src/fairseq2/models/llama/integ.py @@ -6,12 +6,27 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any +from fairseq2.models.llama.factory import LLaMAConfig from fairseq2.models.utils.checkpoint import convert_model_state_dict -def convert_to_reference_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: +def get_ffn_dim_multipliers(architecture: str) -> float: + ffn_dim_multipliers = { + "llama2_70b": 1.3, + "llama3_8b": 1.3, + "llama3_70b": 1.3, + "llama3_1_8b": 1.3, + "llama3_1_70b": 1.3, + "llama3_1_405b": 1.2, + "llama3_2_1b": 1.5, + } + + return ffn_dim_multipliers.get(architecture, 1.0) + + +def convert_to_reference_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]: """Convert a fairseq2 LLaMA checkpoint to the reference format.""" try: model_key = checkpoint["model_key"] @@ -38,3 +53,50 @@ def convert_to_reference_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any } return convert_model_state_dict(state_dict, key_map) + + +def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]: + """Convert Llama's config to a dict mirroring Huggingface's format""" + + def compute_intermediate_size( + n: int, ffn_dim_multiplier: float = 1, multiple_of: int = 256 + ) -> int: + """From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171""" + return multiple_of * ( + (int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of + ) + + if config.rope_scaling is not None: + rope_scaling = { + "factor": config.rope_scaling.factor, + "low_freq_factor": config.rope_scaling.low_freq_factor, + "high_freq_factor": config.rope_scaling.high_freq_factor, + "original_max_position_embeddings": config.rope_scaling.original_context_length, + "rope_type": "llama3", + } + else: + rope_scaling = None + + # we only specify the parameters made explicit in the Huggingface converter + # https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384 + return { + "architectures": ["Fairseq2LlamaForCausalLM"], + "bos_token_id": config.vocab_info.bos_idx, + "eos_token_id": config.vocab_info.eos_idx, + "hidden_size": config.model_dim, + "intermediate_size": compute_intermediate_size( + config.model_dim, + get_ffn_dim_multipliers(arch), + config.ffn_inner_dim_to_multiple, + ), + "max_position_embeddings": config.max_seq_len, + "model_type": "llama", + "num_attention_heads": config.num_attn_heads, + "num_hidden_layers": config.num_layers, + "num_key_value_heads": config.num_key_value_heads, + "rms_norm_eps": 1e-5, + "rope_scaling": rope_scaling, + "rope_theta": config.rope_theta, + "tie_word_embeddings": False, + "vocab_size": config.vocab_info.size, + } diff --git a/src/fairseq2/models/llama/loader.py b/src/fairseq2/models/llama/loader.py index 7053c1956..4780088ba 100644 --- a/src/fairseq2/models/llama/loader.py +++ b/src/fairseq2/models/llama/loader.py @@ -6,16 +6,10 @@ from __future__ import annotations -from pathlib import Path -from typing import Any, Dict, final - -from fairseq2.assets import AssetCard -from fairseq2.data.text import ( - AbstractTextTokenizerLoader, - BasicSentencePieceTokenizer, - TextTokenizer, - load_text_tokenizer, -) +from typing import Any, Mapping + +from torch import Tensor + from fairseq2.gang import Gang from fairseq2.models.config_loader import StandardModelConfigLoader from fairseq2.models.llama.factory import ( @@ -24,93 +18,112 @@ create_llama_model, llama_archs, ) -from fairseq2.models.llama.tokenizer import LLaMA3Tokenizer from fairseq2.models.loader import StandardModelLoader, load_model from fairseq2.models.transformer import ( TransformerDecoderModel, shard_transformer_decoder_model, ) from fairseq2.models.utils.checkpoint import convert_model_state_dict -from fairseq2.typing import override - -load_llama_config = StandardModelConfigLoader( - family=LLAMA_FAMILY, config_kls=LLaMAConfig, arch_configs=llama_archs -) - -@final -class LLaMAModelLoader(StandardModelLoader[TransformerDecoderModel, LLaMAConfig]): - """Loads LLaMA models.""" - - @override - def _shard( - self, model: TransformerDecoderModel, gangs: Dict[str, Gang], card: AssetCard - ) -> None: - gang = gangs["tp"] # tensor parallel - - shard_embed_dim = card.field("shard_embed_dim").get_as_(bool, True) - - shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim) +load_llama_config = StandardModelConfigLoader(LLAMA_FAMILY, LLaMAConfig, llama_archs) def convert_llama_checkpoint( - checkpoint: Dict[str, Any], config: LLaMAConfig -) -> Dict[str, Any]: - """Convert a reference LLaMA checkpoint to fairseq2 format.""" + checkpoint: dict[str, Any], config: LLaMAConfig +) -> dict[str, Any]: + """Convert a reference or Hugging Face LLaMA checkpoint to fairseq2 format.""" # Check if we have a fairseq2 checkpoint. - if "output.weight" not in checkpoint: + if "model" in checkpoint: return checkpoint - key_map = { - # fmt: off - r"^layers\.([0-9]+)\.attention\.wq\.": r"decoder.layers.\1.self_attn.q_proj.", - r"^layers\.([0-9]+)\.attention\.wk\.": r"decoder.layers.\1.self_attn.k_proj.", - r"^layers\.([0-9]+)\.attention\.wv\.": r"decoder.layers.\1.self_attn.v_proj.", - r"^layers\.([0-9]+)\.attention\.wo\.": r"decoder.layers.\1.self_attn.output_proj.", - r"^layers\.([0-9]+)\.attention_norm\.": r"decoder.layers.\1.self_attn_layer_norm.", - r"^layers\.([0-9]+)\.feed_forward\.w1\.": r"decoder.layers.\1.ffn.gate_proj.", - r"^layers\.([0-9]+)\.feed_forward\.w2\.": r"decoder.layers.\1.ffn.output_proj.", - r"^layers\.([0-9]+)\.feed_forward\.w3\.": r"decoder.layers.\1.ffn.inner_proj.", - r"^layers\.([0-9]+)\.ffn_norm\.": r"decoder.layers.\1.ffn_layer_norm.", - r"^norm\.": r"decoder.layer_norm.", - r"^tok_embeddings\.": r"decoder_frontend.embed.", - r"^output\.": r"final_proj.", - # fmt: on - } - - # We do not need the pre-computed 'rope.freqs' buffers. - checkpoint = {k: v for (k, v) in checkpoint.items() if "rope.freqs" not in k} + # Check if we have a sharded checkpoint. + if "weights" in checkpoint: + checkpoint = checkpoint["weights"] + + # Check if we have a reference or Hugging Face checkpoint. + if "lm_head.weight" in checkpoint: # HG + head_dim = config.model_dim // config.num_attn_heads + + def permute_rotary(w: Tensor, num_heads: int) -> Tensor: + # (H, M) -> (H_d, 2, D / 2, M) + w = w.view(num_heads, 2, head_dim // 2, config.model_dim) + + # (H_d, 2, D / 2, M) -> (H_d, D / 2, 2, M) + w = w.transpose(1, 2) + + # (H_d, D / 2, 2, M) -> (H, M) + return w.reshape(-1, config.model_dim) + + for idx in range(config.num_layers): + q_key = f"model.layers.{idx}.self_attn.q_proj.weight" + k_key = f"model.layers.{idx}.self_attn.k_proj.weight" + + q_proj = checkpoint[q_key] + k_proj = checkpoint[k_key] + + q_proj = permute_rotary(q_proj, config.num_attn_heads) + k_proj = permute_rotary(k_proj, config.num_key_value_heads) + + checkpoint[q_key] = q_proj + checkpoint[k_key] = k_proj + + key_map = { + # fmt: off + r"^model\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"decoder.layers.\1.self_attn.q_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"decoder.layers.\1.self_attn.v_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.o_proj\.": r"decoder.layers.\1.self_attn.output_proj.", + r"^model\.layers\.([0-9]+)\.post_attention_layernorm\.": r"decoder.layers.\1.ffn_layer_norm.", + r"^model\.layers\.([0-9]+)\.mlp\.gate_proj\.": r"decoder.layers.\1.ffn.gate_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.down_proj\.": r"decoder.layers.\1.ffn.output_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.up_proj\.": r"decoder.layers.\1.ffn.inner_proj.", + r"^model\.layers\.([0-9]+)\.input_layernorm\.": r"decoder.layers.\1.self_attn_layer_norm.", + r"^model\.norm\.": r"decoder.layer_norm.", + r"^model\.embed_tokens\.": r"decoder_frontend.embed.", + r"^lm_head\.": r"final_proj.", + # fmt: on + } + else: + key_map = { + # fmt: off + r"^layers\.([0-9]+)\.attention\.wq\.": r"decoder.layers.\1.self_attn.q_proj.", + r"^layers\.([0-9]+)\.attention\.wk\.": r"decoder.layers.\1.self_attn.k_proj.", + r"^layers\.([0-9]+)\.attention\.wv\.": r"decoder.layers.\1.self_attn.v_proj.", + r"^layers\.([0-9]+)\.attention\.wo\.": r"decoder.layers.\1.self_attn.output_proj.", + r"^layers\.([0-9]+)\.attention_norm\.": r"decoder.layers.\1.self_attn_layer_norm.", + r"^layers\.([0-9]+)\.feed_forward\.w1\.": r"decoder.layers.\1.ffn.gate_proj.", + r"^layers\.([0-9]+)\.feed_forward\.w2\.": r"decoder.layers.\1.ffn.output_proj.", + r"^layers\.([0-9]+)\.feed_forward\.w3\.": r"decoder.layers.\1.ffn.inner_proj.", + r"^layers\.([0-9]+)\.ffn_norm\.": r"decoder.layers.\1.ffn_layer_norm.", + r"^norm\.": r"decoder.layer_norm.", + r"^tok_embeddings\.": r"decoder_frontend.embed.", + r"^output\.": r"final_proj.", + # fmt: on + } + + # We do not need the pre-computed 'rope.freqs' buffers. + checkpoint = {k: v for (k, v) in checkpoint.items() if "rope.freqs" not in k} checkpoint = convert_model_state_dict(checkpoint, key_map) return {"model": checkpoint} -load_llama_model = LLaMAModelLoader( +def shard_llama_model( + model: TransformerDecoderModel, config: LLaMAConfig, gangs: Mapping[str, Gang] +) -> None: + gang = gangs["tp"] # tensor parallel + + shard_embed_dim = config.max_seq_len < 8192 # LLaMA 1 or 2 + + shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim) + + +load_llama_model = StandardModelLoader( config_loader=load_llama_config, factory=create_llama_model, checkpoint_converter=convert_llama_checkpoint, + sharder=shard_llama_model, ) load_model.register(LLAMA_FAMILY, load_llama_model) - - -@final -class LLaMATokenizerLoader(AbstractTextTokenizerLoader[TextTokenizer]): - """Loads LLaMA tokenizers.""" - - @override - def _load(self, path: Path, card: AssetCard) -> TextTokenizer: - if card.field("use_v2_tokenizer").get_as_(bool, False): - f = card.field("model_config").field("vocab_info").field("eos_idx") - - eot_idx = 128_009 # end-of-turn - - return LLaMA3Tokenizer(path, instruct=f.get_as_(int) == eot_idx) - - return BasicSentencePieceTokenizer(path) - - -load_llama_tokenizer = LLaMATokenizerLoader() - -load_text_tokenizer.register(LLAMA_FAMILY, load_llama_tokenizer) diff --git a/src/fairseq2/models/llama/tokenizer.py b/src/fairseq2/models/llama/tokenizer.py deleted file mode 100644 index 8892ab008..000000000 --- a/src/fairseq2/models/llama/tokenizer.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from pathlib import Path -from typing import Final, Optional, final - -from fairseq2.data.text import TiktokenEncoder, TiktokenTokenizer -from fairseq2.typing import Device, override - - -@final -class LLaMA3Tokenizer(TiktokenTokenizer): - """Represents a LLaMA 3 tokenizer.""" - - _SPLIT_REGEX: Final = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # fmt: skip - - _eos_token: str - - def __init__(self, path: Path, instruct: bool = False) -> None: - """ - :param path: - The path to the tiktoken BPE file. - :param instruct: - If ``True``, uses EOT (end-of-turn) token in-place of EOS token. - """ - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end-of-message - "<|eot_id|>", # end-of-turn - "<|python_tag|>", - ] - - num_reserved_special_tokens = 256 - - for i in range(num_reserved_special_tokens - len(special_tokens)): - special_tokens.append(f"<|reserved_special_token_{2 + i}|>") - - self._eos_token = "<|eot_id|>" if instruct else "<|end_of_text|>" - - super().__init__( - path, - split_regex=self._SPLIT_REGEX, - unk_token=None, - bos_token="<|begin_of_text|>", - eos_token=self._eos_token, - pad_token="<|finetune_right_pad_id|>", - special_tokens=special_tokens, - ) - - @override - def create_encoder( - self, - *, - task: Optional[str] = None, - lang: Optional[str] = None, - mode: Optional[str] = None, - device: Optional[Device] = None, - pin_memory: bool = False, - ) -> TiktokenEncoder: - if task is not None: - raise ValueError(f"`task` must be `None`, but is '{task}' instead.") - - if lang is not None: - raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.") - - if mode is None or mode == "default": - prefix_tokens = ["<|begin_of_text|>"] - suffix_tokens = [self._eos_token] - elif mode == "prompt": - prefix_tokens = ["<|begin_of_text|>"] - # In prompt mode, we expect the generator to finish the sequence. - suffix_tokens = None - elif mode == "prompt_response": - prefix_tokens = [] - suffix_tokens = [self._eos_token] - else: - raise ValueError( - f"`mode` must be 'default' or 'prompt', but is '{mode}' instead." - ) - - return TiktokenEncoder( - self._encoding, - prefix_tokens=prefix_tokens, - suffix_tokens=suffix_tokens, - device=device, - pin_memory=pin_memory, - ) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 56ae92dfb..11b172781 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -7,7 +7,7 @@ from __future__ import annotations from pickle import PickleError -from typing import Any, Dict, Generic, Optional, Protocol, TypeVar, Union, final +from typing import Any, Generic, Mapping, Protocol, TypeVar, cast, final from torch.nn import Module from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present @@ -18,13 +18,12 @@ AssetDownloadManager, AssetError, AssetStore, + InProcAssetDownloadManager, default_asset_store, - default_download_manager, ) from fairseq2.gang import Gang from fairseq2.logging import get_log_writer from fairseq2.models.config_loader import ModelConfigLoader, get_model_family -from fairseq2.models.factory import ModelFactory from fairseq2.nn.utils.module import ( infer_device, load_state_dict, @@ -41,6 +40,8 @@ ModelT_co = TypeVar("ModelT_co", bound=Module, covariant=True) +ModelT_contra = TypeVar("ModelT_contra", bound=Module, contravariant=True) + ModelConfigT = TypeVar("ModelConfigT", bound=DataClass) ModelConfigT_contra = TypeVar( @@ -48,16 +49,37 @@ ) +class ModelFactory(Protocol[ModelConfigT_contra, ModelT_co]): + """Constructs models of type ``ModelT``.""" + + def __call__( + self, + config: ModelConfigT_contra, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> ModelT_co: + """ + :param config: + The model configuration. + :param device: + The device on which to initialize the model. + :param dtype: + The data type of the model parameters and buffers. + """ + + class ModelLoader(Protocol[ModelT_co]): """Loads models of type ``ModelT``.""" def __call__( self, - model_name_or_card: Union[str, AssetCard], + model_name_or_card: str | AssetCard, *, - gangs: Optional[Dict[str, Gang]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + gangs: Mapping[str, Gang] | None = None, + unstructured_config: object = None, + device: Device | None = None, + dtype: DataType | None = None, force: bool = False, progress: bool = True, ) -> ModelT_co: @@ -86,8 +108,8 @@ class CheckpointConverter(Protocol[ModelConfigT_contra]): """Converts checkpoints to fairseq2 format.""" def __call__( - self, checkpoint: Dict[str, Any], config: ModelConfigT_contra - ) -> Dict[str, Any]: + self, checkpoint: dict[str, Any], config: ModelConfigT_contra + ) -> dict[str, Any]: """ :param checkpoint: The checkpoint to convert. @@ -96,15 +118,27 @@ def __call__( """ +class ModelSharder(Protocol[ModelT_contra, ModelConfigT_contra]): + def __call__( + self, + model: ModelT_contra, + config: ModelConfigT_contra, + gangs: Mapping[str, Gang], + ) -> None: + ... + + +@final class StandardModelLoader(ModelLoader[ModelT], Generic[ModelT, ModelConfigT]): """Loads models of type ``ModelT``.""" _asset_store: AssetStore _download_manager: AssetDownloadManager _tensor_loader: TensorLoader - _checkpoint_converter: Optional[CheckpointConverter[ModelConfigT]] _config_loader: ModelConfigLoader[ModelConfigT] _factory: ModelFactory[ModelConfigT, ModelT] + _checkpoint_converter: CheckpointConverter[ModelConfigT] | None + _sharder: ModelSharder[ModelT, ModelConfigT] | None _restrict_checkpoints: bool _skip_meta_init: bool @@ -113,25 +147,19 @@ def __init__( *, config_loader: ModelConfigLoader[ModelConfigT], factory: ModelFactory[ModelConfigT, ModelT], + asset_store: AssetStore | None = None, + download_manager: AssetDownloadManager | None = None, + tensor_loader: TensorLoader | None = None, + checkpoint_converter: CheckpointConverter[ModelConfigT] | None = None, + sharder: ModelSharder[ModelT, ModelConfigT] | None = None, restrict_checkpoints: bool = True, skip_meta_init: bool = False, - asset_store: Optional[AssetStore] = None, - download_manager: Optional[AssetDownloadManager] = None, - tensor_loader: Optional[TensorLoader] = None, - checkpoint_converter: Optional[CheckpointConverter[ModelConfigT]] = None, ) -> None: """ :param config_loader: The configuration loader. :param factory: The factory to construct models. - :param restrict_checkpoints: - If ``True``, restricts the Python unpickler to load only tensors, - primitive types, and dictionaries. - :param skip_meta_init: - If ``True``, skips meta device initialization and constructs the - model directly on the requested device. Should be used with models - that do not support PyTorch's ``reset_parameters()`` convention. :param asset_store: The asset store where to check for available models. If ``None``, the default asset store will be used. @@ -143,24 +171,34 @@ def __init__( :param checkpoint_converter: The converter to which loaded checkpoints will be passed for further processing. + :param sharder: + The model sharder for tensor parallelism. + :param restrict_checkpoints: + If ``True``, restricts the Python unpickler to load only tensors, + primitive types, and dictionaries. + :param skip_meta_init: + If ``True``, skips meta device initialization and constructs the + model directly on the requested device. Should be used with models + that do not support PyTorch's ``reset_parameters()`` convention. """ self._asset_store = asset_store or default_asset_store - self._download_manager = download_manager or default_download_manager + self._download_manager = download_manager or InProcAssetDownloadManager() self._tensor_loader = tensor_loader or load_tensors - self._checkpoint_converter = checkpoint_converter self._config_loader = config_loader self._factory = factory + self._checkpoint_converter = checkpoint_converter + self._sharder = sharder self._restrict_checkpoints = restrict_checkpoints self._skip_meta_init = skip_meta_init - @final def __call__( self, - model_name_or_card: Union[str, AssetCard], + model_name_or_card: str | AssetCard, *, - gangs: Optional[Dict[str, Gang]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + gangs: Mapping[str, Gang] | None = None, + unstructured_config: object = None, + device: Device | None = None, + dtype: DataType | None = None, force: bool = False, progress: bool = True, ) -> ModelT: @@ -186,7 +224,7 @@ def __call__( num_shards = card.field("num_shards").get_as_(int, default=1) if num_shards < 1: raise AssetCardError( - f"The value of the field 'num_shards' of the asset card '{card.name}' must be greater than or equal to 1, but is {num_shards} instead." + card.name, f"The value of the field 'num_shards' of the asset card '{card.name}' must be greater than or equal to 1, but is {num_shards} instead." # fmt: skip ) if num_shards > 1: @@ -207,7 +245,7 @@ def __call__( model = None - config = self._config_loader(card) + config = self._config_loader(card, unstructured_config) if device.type == "meta": try: @@ -221,7 +259,14 @@ def __call__( ) from ex if gang is not None and gang.size > 1: - self._shard(model, gangs, card) # type: ignore[arg-type] + if self._sharder is None: + raise RuntimeError( + f"{card.name} has a sharded checkpoint, but has no model sharder. Please file a bug report to the model author." + ) + + assert gangs is not None + + self._sharder(model, config, gangs) return model @@ -240,7 +285,7 @@ def __call__( ) except ValueError as ex: raise AssetCardError( - f"The value of the field 'checkpoint' of the asset card '{card.name}' must be URI. See nested exception for details." + card.name, f"The value of the field 'checkpoint' of the asset card '{card.name}' must be URI. See nested exception for details." # fmt: skip ) from ex try: @@ -272,10 +317,17 @@ def __call__( model = self._factory(config, device=init_device, dtype=dtype) if gang is not None and gang.size > 1: - self._shard(model, gangs, card) # type: ignore[arg-type] + if self._sharder is None: + raise RuntimeError( + f"{card.name} has a sharded checkpoint, but has no model sharder. Please file a bug report to the model author." + ) + + assert gangs is not None + + self._sharder(model, config, gangs) try: - model_device = infer_device(model, name="model") + model_device = infer_device(model) except ValueError as ex: raise RuntimeError( "`factory` returned a model that is not constructed correctly. See nested exception for details." @@ -288,12 +340,12 @@ def __call__( # Load the model. try: - model_key = checkpoint["model_key"] + model_key = cast(str, checkpoint["model_key"]) except KeyError: model_key = "model" try: - state_dict = checkpoint[model_key] + state_dict = cast(dict[str, object], checkpoint[model_key]) except KeyError: raise AssetError( f"The checkpoint of {card.name} does not contain a '{model_key}' entry." @@ -316,20 +368,15 @@ def __call__( return model - def _shard(self, model: ModelT, gangs: Dict[str, Gang], card: AssetCard) -> None: - raise RuntimeError( - f"{card.name} has a sharded checkpoint, but has no model sharder. Please file a bug report to the model author." - ) - @final class DelegatingModelLoader(ModelLoader[ModelT]): """Loads models of type ``ModelT`` using registered loaders.""" _asset_store: AssetStore - _loaders: Dict[str, ModelLoader[ModelT]] + _loaders: dict[str, ModelLoader[ModelT]] - def __init__(self, *, asset_store: Optional[AssetStore] = None) -> None: + def __init__(self, *, asset_store: AssetStore | None = None) -> None: """ :param asset_store: The asset store where to check for available models. If ``None``, @@ -341,11 +388,12 @@ def __init__(self, *, asset_store: Optional[AssetStore] = None) -> None: def __call__( self, - model_name_or_card: Union[str, AssetCard], + model_name_or_card: str | AssetCard, *, - gangs: Optional[Dict[str, Gang]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + gangs: Mapping[str, Gang] | None = None, + unstructured_config: object = None, + device: Device | None = None, + dtype: DataType | None = None, force: bool = False, progress: bool = True, ) -> ModelT: @@ -360,12 +408,13 @@ def __call__( loader = self._loaders[family] except KeyError: raise AssetError( - f"The value of the field 'model_family' of the asset card '{card.name}' must be a supported model family, but '{family}' has no registered loader." + f"The value of the field 'model_family' of the asset card '{card.name}' must be a supported model family, but is '{family}' instead." ) from None return loader( model_name_or_card, gangs=gangs, + unstructured_config=unstructured_config, device=device, dtype=dtype, force=force, @@ -383,12 +432,12 @@ def register(self, family: str, loader: ModelLoader[ModelT]) -> None: """ if family in self._loaders: raise ValueError( - f"`family` must be a unique model family name, but '{family}' has already a registered loader." + f"`family` must be a unique model family name, but '{family}' is already registered." ) self._loaders[family] = loader - def supports(self, model_name_or_card: Union[str, AssetCard]) -> bool: + def supports(self, model_name_or_card: str | AssetCard) -> bool: """Return ``True`` if the specified model has a registered loader.""" if isinstance(model_name_or_card, AssetCard): card = model_name_or_card diff --git a/src/fairseq2/models/mistral/__init__.py b/src/fairseq2/models/mistral/__init__.py index a1a634983..4a1797ed8 100644 --- a/src/fairseq2/models/mistral/__init__.py +++ b/src/fairseq2/models/mistral/__init__.py @@ -6,7 +6,6 @@ from __future__ import annotations -from fairseq2.models.mistral.chatbot import MistralChatbot as MistralChatbot from fairseq2.models.mistral.factory import MISTRAL_FAMILY as MISTRAL_FAMILY from fairseq2.models.mistral.factory import MistralBuilder as MistralBuilder from fairseq2.models.mistral.factory import MistralConfig as MistralConfig @@ -15,9 +14,6 @@ from fairseq2.models.mistral.factory import mistral_archs as mistral_archs from fairseq2.models.mistral.loader import load_mistral_config as load_mistral_config from fairseq2.models.mistral.loader import load_mistral_model as load_mistral_model -from fairseq2.models.mistral.loader import ( - load_mistral_tokenizer as load_mistral_tokenizer, -) # isort: split diff --git a/src/fairseq2/models/mistral/factory.py b/src/fairseq2/models/mistral/factory.py index 5ffcde8f4..c071dbc34 100644 --- a/src/fairseq2/models/mistral/factory.py +++ b/src/fairseq2/models/mistral/factory.py @@ -7,11 +7,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Final, Optional +from typing import Final from fairseq2.config_registry import ConfigRegistry from fairseq2.data import VocabularyInfo -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.transformer import ( TransformerDecoderModel, TransformerEmbeddingFrontend, @@ -38,7 +38,7 @@ MISTRAL_FAMILY: Final = "mistral" -@dataclass +@dataclass(kw_only=True) class MistralConfig: """Holds the configuration of a Mistral model. @@ -92,16 +92,16 @@ class MistralBuilder: """ _config: MistralConfig - _device: Optional[Device] - _dtype: Optional[DataType] - _pos_encoder: Optional[RotaryEncoder] + _device: Device | None + _dtype: DataType | None + _pos_encoder: RotaryEncoder | None def __init__( self, config: MistralConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -132,7 +132,7 @@ def build_model(self) -> TransformerDecoderModel: dtype=self._dtype, ) - return TransformerDecoderModel( + model = TransformerDecoderModel( decoder_frontend, decoder, final_proj, @@ -140,6 +140,10 @@ def build_model(self) -> TransformerDecoderModel: self._config.vocab_info, ) + model.set_family(MISTRAL_FAMILY) + + return model + def build_decoder_frontend(self) -> TransformerFrontend: """Build a Transformer decoder front-end.""" embed = StandardEmbedding( @@ -238,8 +242,8 @@ def build_layer_norm( self, model_dim: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> LayerNorm: """Build a Layer Normalization module.""" return RMSNorm(model_dim, bias=False, device=device, dtype=dtype) @@ -248,26 +252,13 @@ def build_layer_norm( def create_mistral_model( config: MistralConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> TransformerDecoderModel: - """Create a Mistral model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - model = MistralBuilder(config, device=device, dtype=dtype).build_model() - - return model.set_family(MISTRAL_FAMILY) + """Create a Mistral model.""" + return MistralBuilder(config, device=device, dtype=dtype).build_model() -create_model.register( - family=MISTRAL_FAMILY, - factory=create_mistral_model, - config_kls=MistralConfig, - arch_configs=mistral_archs, +model_factories.register( + MISTRAL_FAMILY, create_mistral_model, MistralConfig, mistral_archs ) diff --git a/src/fairseq2/models/mistral/loader.py b/src/fairseq2/models/mistral/loader.py index 7de2a1267..013109535 100644 --- a/src/fairseq2/models/mistral/loader.py +++ b/src/fairseq2/models/mistral/loader.py @@ -6,12 +6,8 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any -from fairseq2.data.text import ( - default_basic_sentencepiece_tokenizer_loader, - load_text_tokenizer, -) from fairseq2.models.config_loader import StandardModelConfigLoader from fairseq2.models.loader import StandardModelLoader, load_model from fairseq2.models.mistral.factory import ( @@ -23,15 +19,16 @@ from fairseq2.models.utils.checkpoint import convert_model_state_dict load_mistral_config = StandardModelConfigLoader( - family=MISTRAL_FAMILY, config_kls=MistralConfig, arch_configs=mistral_archs + MISTRAL_FAMILY, MistralConfig, mistral_archs ) def convert_mistral_checkpoint( - checkpoint: Dict[str, Any], config: MistralConfig -) -> Dict[str, Any]: + checkpoint: dict[str, Any], config: MistralConfig +) -> dict[str, Any]: """Convert a reference Mistral checkpoint to fairseq2 format.""" - if "output.weight" not in checkpoint: + # Check if we have a fairseq2 checkpoint. + if "model" in checkpoint: return checkpoint key_map = { @@ -63,7 +60,3 @@ def convert_mistral_checkpoint( ) load_model.register(MISTRAL_FAMILY, load_mistral_model) - -load_mistral_tokenizer = default_basic_sentencepiece_tokenizer_loader - -load_text_tokenizer.register(MISTRAL_FAMILY, load_mistral_tokenizer) diff --git a/src/fairseq2/models/model.py b/src/fairseq2/models/model.py index db614bbb6..970f7ccc3 100644 --- a/src/fairseq2/models/model.py +++ b/src/fairseq2/models/model.py @@ -6,8 +6,6 @@ from __future__ import annotations -from typing import Optional - from torch.nn import Module from typing_extensions import Self @@ -15,7 +13,7 @@ class Model(Module): """Represents a machine learning model.""" - _family: Optional[str] + _family: str | None def __init__(self) -> None: super().__init__() @@ -24,16 +22,11 @@ def __init__(self) -> None: def set_family(self, family: str) -> Self: """Set the family of the model.""" - if self._family is not None: - raise ValueError( - f"The model must not have a prior family, but has already '{self._family}'." - ) - self._family = family return self @property - def family(self) -> Optional[str]: + def family(self) -> str | None: """The family of the model.""" return self._family diff --git a/src/fairseq2/models/nllb/__init__.py b/src/fairseq2/models/nllb/__init__.py index b5a0ac7ed..710ae6228 100644 --- a/src/fairseq2/models/nllb/__init__.py +++ b/src/fairseq2/models/nllb/__init__.py @@ -6,9 +6,6 @@ from __future__ import annotations -from fairseq2.models.nllb.loader import load_nllb_tokenizer as load_nllb_tokenizer -from fairseq2.models.nllb.tokenizer import NllbTokenizer as NllbTokenizer - # isort: split import fairseq2.models.nllb.archs # Register architectures. diff --git a/src/fairseq2/models/nllb/archs.py b/src/fairseq2/models/nllb/archs.py index 8134531dd..11750082d 100644 --- a/src/fairseq2/models/nllb/archs.py +++ b/src/fairseq2/models/nllb/archs.py @@ -22,6 +22,7 @@ def _dense_300m() -> TransformerConfig: config.num_encoder_layers = 6 config.num_decoder_layers = 6 config.ffn_inner_dim = 1024 * 4 + config.dropout_p = 0.3 return config diff --git a/src/fairseq2/models/nllb/loader.py b/src/fairseq2/models/nllb/loader.py deleted file mode 100644 index 2ec91117f..000000000 --- a/src/fairseq2/models/nllb/loader.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from pathlib import Path -from typing import List, final - -from fairseq2.assets import AssetCard -from fairseq2.data.text import AbstractTextTokenizerLoader, load_text_tokenizer -from fairseq2.models.nllb.tokenizer import NllbTokenizer -from fairseq2.typing import override - - -@final -class NllbTokenizerLoader(AbstractTextTokenizerLoader[NllbTokenizer]): - """Loads NLLB tokenizers.""" - - @override - def _load(self, path: Path, card: AssetCard) -> NllbTokenizer: - langs = card.field("langs").as_(List[str]) - - default_lang = card.field("default_lang").as_(str) - - return NllbTokenizer(path, langs, default_lang) - - -load_nllb_tokenizer = NllbTokenizerLoader() - -load_text_tokenizer.register("nllb", load_nllb_tokenizer) diff --git a/src/fairseq2/models/s2t_transformer/__init__.py b/src/fairseq2/models/s2t_transformer/__init__.py index 38ed516da..5408f5b30 100644 --- a/src/fairseq2/models/s2t_transformer/__init__.py +++ b/src/fairseq2/models/s2t_transformer/__init__.py @@ -36,12 +36,6 @@ from fairseq2.models.s2t_transformer.loader import ( load_s2t_transformer_model as load_s2t_transformer_model, ) -from fairseq2.models.s2t_transformer.loader import ( - load_s2t_transformer_tokenizer as load_s2t_transformer_tokenizer, -) -from fairseq2.models.s2t_transformer.tokenizer import ( - S2TTransformerTokenizer as S2TTransformerTokenizer, -) # isort: split diff --git a/src/fairseq2/models/s2t_transformer/factory.py b/src/fairseq2/models/s2t_transformer/factory.py index 8ad36b3e6..20f5b25e7 100644 --- a/src/fairseq2/models/s2t_transformer/factory.py +++ b/src/fairseq2/models/s2t_transformer/factory.py @@ -7,14 +7,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Final, Optional +from typing import Final from torch.nn import SiLU from fairseq2.config_registry import ConfigRegistry from fairseq2.data import VocabularyInfo from fairseq2.models.conformer import ConformerBlock, ConformerConvolution -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.s2t_transformer.feature_extractor import Conv1dFbankSubsampler from fairseq2.models.s2t_transformer.frontend import S2TTransformerFrontend from fairseq2.models.transformer import ( @@ -53,7 +53,7 @@ S2T_TRANSFORMER_FAMILY: Final = "s2t_transformer" -@dataclass +@dataclass(kw_only=True) class S2TTransformerConfig: """Holds the configuration of an S2T Transformer model. @@ -122,16 +122,16 @@ class S2TTransformerBuilder: """ _config: S2TTransformerConfig - _device: Optional[Device] - _dtype: Optional[DataType] - _rel_pos_encoding: Optional[RelativePositionalEncoding] + _device: Device | None + _dtype: DataType | None + _rel_pos_encoding: RelativePositionalEncoding | None def __init__( self, config: S2TTransformerConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -164,7 +164,7 @@ def build_model(self) -> TransformerModel: dtype=self._dtype, ) - return TransformerModel( + model = TransformerModel( encoder_frontend, encoder, decoder_frontend, @@ -174,6 +174,10 @@ def build_model(self) -> TransformerModel: self._config.target_vocab_info, ) + model.set_family(S2T_TRANSFORMER_FAMILY) + + return model + def build_encoder_frontend(self) -> TransformerFrontend: """Build a Transformer encoder front-end.""" feat_extractor = Conv1dFbankSubsampler( @@ -218,7 +222,7 @@ def build_decoder_frontend(self) -> TransformerFrontend: dtype=self._dtype, ) - def build_source_position_encoder(self) -> Optional[PositionEncoder]: + def build_source_position_encoder(self) -> PositionEncoder | None: """Build a position encoder for source sequences.""" if self._config.use_relative_pos: return None @@ -392,26 +396,16 @@ def build_ffn(self, use_swish: bool = False) -> FeedForwardNetwork: def create_s2t_transformer_model( config: S2TTransformerConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> TransformerModel: - """Create an S2T Transformer model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - model = S2TTransformerBuilder(config, device=device, dtype=dtype).build_model() - - return model.set_family(S2T_TRANSFORMER_FAMILY) + """Create an S2T Transformer model.""" + return S2TTransformerBuilder(config, device=device, dtype=dtype).build_model() -create_model.register( - family=S2T_TRANSFORMER_FAMILY, - factory=create_s2t_transformer_model, - config_kls=S2TTransformerConfig, - arch_configs=s2t_transformer_archs, +model_factories.register( + S2T_TRANSFORMER_FAMILY, + create_s2t_transformer_model, + S2TTransformerConfig, + s2t_transformer_archs, ) diff --git a/src/fairseq2/models/s2t_transformer/feature_extractor.py b/src/fairseq2/models/s2t_transformer/feature_extractor.py index 486e28df4..f610b2396 100644 --- a/src/fairseq2/models/s2t_transformer/feature_extractor.py +++ b/src/fairseq2/models/s2t_transformer/feature_extractor.py @@ -6,14 +6,16 @@ from __future__ import annotations -from typing import Final, Optional, Sequence, Tuple, final +from collections.abc import Sequence +from typing import Final, final from torch import Tensor from torch.nn import GLU, Conv1d, Sequential +from typing_extensions import override from fairseq2.models.feature_extractor import SequenceFeatureExtractor from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -34,9 +36,9 @@ def __init__( inner_dim: int, feature_dim: int, *, - kernel_sizes: Optional[Sequence[int]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + kernel_sizes: Sequence[int] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param num_channels: @@ -87,8 +89,8 @@ def __init__( @override def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: """See the base :meth:`SequenceFeatureExtractor.forward`. :param seqs: diff --git a/src/fairseq2/models/s2t_transformer/frontend.py b/src/fairseq2/models/s2t_transformer/frontend.py index c4bbf6266..90bb6dafb 100644 --- a/src/fairseq2/models/s2t_transformer/frontend.py +++ b/src/fairseq2/models/s2t_transformer/frontend.py @@ -7,17 +7,18 @@ from __future__ import annotations import math -from typing import Optional, Tuple, final +from typing import final from torch import Tensor from torch.nn import Dropout +from typing_extensions import override from fairseq2.models.feature_extractor import SequenceFeatureExtractor from fairseq2.models.transformer import TransformerFrontend from fairseq2.nn import Linear, PositionEncoder, Projection from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -25,22 +26,22 @@ class S2TTransformerFrontend(TransformerFrontend): """Represents a Transformer encoder front-end as described in Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1911.08460`.""" - feature_extractor: Optional[SequenceFeatureExtractor] + feature_extractor: SequenceFeatureExtractor | None scale: float - pos_encoder: Optional[PositionEncoder] - proj: Optional[Projection] - dropout: Optional[Dropout] + pos_encoder: PositionEncoder | None + proj: Projection | None + dropout: Dropout | None def __init__( self, model_dim: int, - feature_extractor: Optional[SequenceFeatureExtractor], - pos_encoder: Optional[PositionEncoder], + feature_extractor: SequenceFeatureExtractor | None, + pos_encoder: PositionEncoder | None, *, proj: bool = False, dropout_p: float = 0.0, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -97,10 +98,10 @@ def __init__( def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: if state_bag is not None: raise ValueError( "`S2TTransformerFrontend` does not support incremental decoding." diff --git a/src/fairseq2/models/s2t_transformer/loader.py b/src/fairseq2/models/s2t_transformer/loader.py index 2fc9754f6..f340b7c65 100644 --- a/src/fairseq2/models/s2t_transformer/loader.py +++ b/src/fairseq2/models/s2t_transformer/loader.py @@ -6,11 +6,8 @@ from __future__ import annotations -from pathlib import Path -from typing import Any, Dict, Final, List, final +from typing import Any -from fairseq2.assets import AssetCard -from fairseq2.data.text import AbstractTextTokenizerLoader, load_text_tokenizer from fairseq2.models.config_loader import StandardModelConfigLoader from fairseq2.models.loader import StandardModelLoader, load_model from fairseq2.models.s2t_transformer.factory import ( @@ -19,26 +16,23 @@ create_s2t_transformer_model, s2t_transformer_archs, ) -from fairseq2.models.s2t_transformer.tokenizer import S2TTransformerTokenizer from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint -from fairseq2.typing import override load_s2t_transformer_config = StandardModelConfigLoader( - family=S2T_TRANSFORMER_FAMILY, - config_kls=S2TTransformerConfig, - arch_configs=s2t_transformer_archs, + S2T_TRANSFORMER_FAMILY, S2TTransformerConfig, s2t_transformer_archs ) def convert_s2t_transformer_checkpoint( - checkpoint: Dict[str, Any], config: S2TTransformerConfig -) -> Dict[str, Any]: + checkpoint: dict[str, Any], config: S2TTransformerConfig +) -> dict[str, Any]: """Convert a fairseq S2T Transformer checkpoint to fairseq2 format.""" try: state_dict = checkpoint["model"] except KeyError: return checkpoint + # Check if we have a fairseq2 checkpoint. if "decoder.output_projection.weight" not in state_dict: return checkpoint @@ -97,27 +91,3 @@ def convert_s2t_transformer_checkpoint( ) load_model.register(S2T_TRANSFORMER_FAMILY, load_s2t_transformer_model) - - -@final -class S2TTransformerTokenizerLoader( - AbstractTextTokenizerLoader[S2TTransformerTokenizer] -): - """Loads S2T Transformer tokenizers.""" - - _VALID_TASKS: Final = {"translation", "transcription"} - - @override - def _load(self, path: Path, card: AssetCard) -> S2TTransformerTokenizer: - task = card.field("task").as_one_of(self._VALID_TASKS) - - target_langs = card.field("target_langs").as_(List[str]) - - return S2TTransformerTokenizer( - path, task, set(target_langs), default_target_lang=target_langs[0] - ) - - -load_s2t_transformer_tokenizer = S2TTransformerTokenizerLoader() - -load_text_tokenizer.register(S2T_TRANSFORMER_FAMILY, load_s2t_transformer_tokenizer) diff --git a/src/fairseq2/models/seq2seq.py b/src/fairseq2/models/seq2seq.py index c65033f65..2ea259f9f 100644 --- a/src/fairseq2/models/seq2seq.py +++ b/src/fairseq2/models/seq2seq.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Tuple, final +from typing import Any, final from torch import Tensor @@ -58,7 +58,7 @@ class Seq2SeqBatch: the batch size, :math:`S_{src}` is the source sequence length, and :math:`*` is any number of sequence-specific dimensions including none.""" - source_padding_mask: Optional[PaddingMask] + source_padding_mask: PaddingMask | None """The padding mask of :attr:`source_seqs`. *Shape:* :math:`(N,S_{src})`, where :math:`N` is the batch size and :math:`S_{src}` is the source sequence length.""" @@ -68,7 +68,7 @@ class Seq2SeqBatch: the batch size, :math:`S_{tgt}` is the target sequence length, and :math:`*` is any number of sequence-specific dimensions including none.""" - target_padding_mask: Optional[PaddingMask] + target_padding_mask: PaddingMask | None """The padding mask of :attr:`target_seqs`. *Shape:* :math:`(N,S_{tgt})`, where :math:`N` is the batch size and :math:`S_{tgt}` is the target sequence length.""" @@ -96,7 +96,7 @@ def num_target_elements(self) -> int: return int(self.target_padding_mask.seq_lens.sum()) -def as_auto_regressive_input(batch: Seq2SeqBatch) -> Tuple[Seq2SeqBatch, SequenceBatch]: +def as_auto_regressive_input(batch: Seq2SeqBatch) -> tuple[Seq2SeqBatch, SequenceBatch]: """Use ``batch`` to train an auto-regressive model. :returns: diff --git a/src/fairseq2/models/sequence.py b/src/fairseq2/models/sequence.py index 6545b8a53..38154e8a7 100644 --- a/src/fairseq2/models/sequence.py +++ b/src/fairseq2/models/sequence.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Tuple, final +from typing import Any, final import torch from torch import Tensor @@ -56,11 +56,11 @@ class SequenceBatch: size, :math:`S` is the sequence length, and :math:`*` is any number of sequence-specific dimensions including none.""" - padding_mask: Optional[PaddingMask] + padding_mask: PaddingMask | None """The padding mask of :attr:`seqs`. *Shape:* :math:`(N,S)`, where :math:`N` is the batch size and :math:`S` is the sequence length.""" - target_mask: Optional[Tensor] = None + target_mask: Tensor | None = None """The mask specifying the elements in ``seqs`` that should be treated as targets during model training or validation. *Shape:* :math:`(N,S)`, where :math:`N` is the batch size and :math:`S` is the sequence length.""" @@ -90,7 +90,7 @@ def num_target_elements(self) -> int: def as_auto_regressive_input( batch: SequenceBatch, -) -> Tuple[SequenceBatch, SequenceBatch]: +) -> tuple[SequenceBatch, SequenceBatch]: """Use ``batch`` to train an auto-regressive model. :returns: @@ -132,14 +132,14 @@ class SequenceModelOutput: :math:`N` is the batch size, :math:`S` is the sequence length, and :math:`T` is the size of the vocabulary.""" - pad_idx: Optional[int] + pad_idx: int | None """The index of the PAD symbols in the vocabulary.""" def compute_loss( self, targets: Tensor, *, - loss_mask: Optional[Tensor] = None, + loss_mask: Tensor | None = None, ignore_prefix_size: int = 0, label_smoothing: float = 0.0, ) -> Tensor: diff --git a/src/fairseq2/models/transformer/decoder_model.py b/src/fairseq2/models/transformer/decoder_model.py index cc6602051..0f5407aa2 100644 --- a/src/fairseq2/models/transformer/decoder_model.py +++ b/src/fairseq2/models/transformer/decoder_model.py @@ -6,9 +6,10 @@ from __future__ import annotations -from typing import Optional, Tuple, final +from typing import final from torch import Tensor +from typing_extensions import override from fairseq2.data import VocabularyInfo from fairseq2.gang import Gang @@ -35,7 +36,6 @@ StandardMultiheadAttention, TransformerDecoder, ) -from fairseq2.typing import override @final @@ -77,10 +77,10 @@ def __init__( def decode( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, PaddingMask]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask]: seqs, padding_mask = self.decoder_frontend( seqs, padding_mask, state_bag=state_bag ) @@ -93,7 +93,7 @@ def decode( @override def project( - self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask] + self, decoder_output: Tensor, decoder_padding_mask: PaddingMask | None ) -> SequenceModelOutput: logits = self.final_proj(decoder_output) diff --git a/src/fairseq2/models/transformer/factory.py b/src/fairseq2/models/transformer/factory.py index befab47b3..a5417d6b2 100644 --- a/src/fairseq2/models/transformer/factory.py +++ b/src/fairseq2/models/transformer/factory.py @@ -7,11 +7,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Final, Optional +from typing import Final from fairseq2.config_registry import ConfigRegistry from fairseq2.data import VocabularyInfo -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.transformer.frontend import ( TransformerEmbeddingFrontend, TransformerFrontend, @@ -45,7 +45,7 @@ TRANSFORMER_FAMILY: Final = "transformer" -@dataclass +@dataclass(kw_only=True) class TransformerConfig: """Holds the configuration of a Transformer model. @@ -102,15 +102,15 @@ class TransformerBuilder: """ _config: TransformerConfig - _device: Optional[Device] - _dtype: Optional[DataType] + _device: Device | None + _dtype: DataType | None def __init__( self, config: TransformerConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -135,7 +135,7 @@ def build_model(self) -> TransformerModel: final_proj = TiedProjection(embed.weight, bias=None) - return TransformerModel( + model = TransformerModel( frontend, encoder, frontend, @@ -145,6 +145,10 @@ def build_model(self) -> TransformerModel: self._config.vocab_info, ) + model.set_family(TRANSFORMER_FAMILY) + + return model + def build_embedding(self) -> StandardEmbedding: """Build an embedding table.""" return StandardEmbedding( @@ -259,26 +263,13 @@ def build_ffn(self) -> FeedForwardNetwork: def create_transformer_model( config: TransformerConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> TransformerModel: - """Create a Transformer model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - model = TransformerBuilder(config, device=device, dtype=dtype).build_model() - - return model.set_family(TRANSFORMER_FAMILY) + """Create a Transformer model.""" + return TransformerBuilder(config, device=device, dtype=dtype).build_model() -create_model.register( - family=TRANSFORMER_FAMILY, - factory=create_transformer_model, - config_kls=TransformerConfig, - arch_configs=transformer_archs, +model_factories.register( + TRANSFORMER_FAMILY, create_transformer_model, TransformerConfig, transformer_archs ) diff --git a/src/fairseq2/models/transformer/frontend.py b/src/fairseq2/models/transformer/frontend.py index 47e995910..518f9cd12 100644 --- a/src/fairseq2/models/transformer/frontend.py +++ b/src/fairseq2/models/transformer/frontend.py @@ -8,16 +8,17 @@ import math from abc import ABC, abstractmethod -from typing import Optional, Tuple, final +from typing import final from torch import Tensor from torch.nn import Dropout, Module +from typing_extensions import override from fairseq2.nn import Embedding, LayerNorm, PositionEncoder from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask from fairseq2.nn.transformer import LayerNormFactory, create_standard_layer_norm -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device class TransformerFrontend(Module, ABC): @@ -38,10 +39,10 @@ def __init__(self, model_dim: int) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: """ :param seqs: The sequences to process. *Shape:* :math:`(N,S,*)`, where :math:`N` @@ -75,21 +76,21 @@ class TransformerEmbeddingFrontend(TransformerFrontend): embed: Embedding scale: float - pos_encoder: Optional[PositionEncoder] - layer_norm: Optional[LayerNorm] - dropout: Optional[Dropout] + pos_encoder: PositionEncoder | None + layer_norm: LayerNorm | None + dropout: Dropout | None def __init__( self, embed: Embedding, - pos_encoder: Optional[PositionEncoder], + pos_encoder: PositionEncoder | None, *, no_scale: bool = False, layer_norm: bool = False, dropout_p: float = 0.0, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param embed: @@ -142,10 +143,10 @@ def __init__( def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: embeds = self.embed(seqs) if self.scale != 1.0: diff --git a/src/fairseq2/models/transformer/loader.py b/src/fairseq2/models/transformer/loader.py index 16d55ede2..1899b3e58 100644 --- a/src/fairseq2/models/transformer/loader.py +++ b/src/fairseq2/models/transformer/loader.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import torch @@ -21,21 +21,20 @@ from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint load_transformer_config = StandardModelConfigLoader( - family=TRANSFORMER_FAMILY, - config_kls=TransformerConfig, - arch_configs=transformer_archs, + TRANSFORMER_FAMILY, TransformerConfig, transformer_archs ) def convert_transformer_checkpoint( - checkpoint: Dict[str, Any], config: TransformerConfig -) -> Dict[str, Any]: + checkpoint: dict[str, Any], config: TransformerConfig +) -> dict[str, Any]: """Convert a fairseq Transformer checkpoint to fairseq2 format.""" try: state_dict = checkpoint["model"] except KeyError: return checkpoint + # Check if we have a fairseq2 checkpoint. if "decoder.output_projection.weight" not in state_dict: return checkpoint diff --git a/src/fairseq2/models/transformer/model.py b/src/fairseq2/models/transformer/model.py index c28aebeb9..1b96613e1 100644 --- a/src/fairseq2/models/transformer/model.py +++ b/src/fairseq2/models/transformer/model.py @@ -6,10 +6,11 @@ from __future__ import annotations -from typing import Optional, Tuple, final +from typing import final import torch.nn as nn from torch import Tensor +from typing_extensions import override from fairseq2.data import VocabularyInfo from fairseq2.models.encoder_decoder import EncoderDecoderModel @@ -19,7 +20,6 @@ from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder -from fairseq2.typing import override @final @@ -71,8 +71,8 @@ def __init__( @override def encode( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: seqs, padding_mask = self.encoder_frontend(seqs, padding_mask) return self.encoder(seqs, padding_mask) # type: ignore[no-any-return] @@ -81,12 +81,12 @@ def encode( def decode( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, encoder_output: Tensor, - encoder_padding_mask: Optional[PaddingMask], + encoder_padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: seqs, padding_mask = self.decoder_frontend( seqs, padding_mask, state_bag=state_bag ) @@ -101,7 +101,7 @@ def decode( @override def project( - self, decoder_output: Tensor, decoder_padding_mask: Optional[PaddingMask] + self, decoder_output: Tensor, decoder_padding_mask: PaddingMask | None ) -> SequenceModelOutput: logits = self.final_proj(decoder_output) diff --git a/src/fairseq2/models/utils/checkpoint.py b/src/fairseq2/models/utils/checkpoint.py index 9530a1969..ddf69aa19 100644 --- a/src/fairseq2/models/utils/checkpoint.py +++ b/src/fairseq2/models/utils/checkpoint.py @@ -7,12 +7,13 @@ from __future__ import annotations import re -from typing import Any, Dict, Mapping +from collections.abc import Mapping +from typing import Any def convert_model_state_dict( - state_dict: Dict[str, Any], key_map: Mapping[str, str] -) -> Dict[str, Any]: + state_dict: dict[str, Any], key_map: Mapping[str, str] +) -> dict[str, Any]: """Convert a model state dictionary to fairseq2. :param state_dict: @@ -42,8 +43,8 @@ def get_new_key(old_key: str) -> str: def convert_fairseq_checkpoint( - checkpoint: Dict[str, Any], key_map: Mapping[str, str] -) -> Dict[str, Any]: + checkpoint: dict[str, Any], key_map: Mapping[str, str] +) -> dict[str, Any]: """Convert a fairseq checkpoint to fairseq2. :param checkpoint: diff --git a/src/fairseq2/models/vit/__init__.py b/src/fairseq2/models/vit/__init__.py new file mode 100644 index 000000000..d2a302cd2 --- /dev/null +++ b/src/fairseq2/models/vit/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.models.vit.feature_extractor import ( + Conv2dPatchFeatureExtractor as Conv2dPatchFeatureExtractor, +) +from fairseq2.models.vit.feature_extractor import ( + Conv3dPatchFeatureExtractor as Conv3dPatchFeatureExtractor, +) +from fairseq2.models.vit.feature_extractor import ( + PatchFeatureExtractor as PatchFeatureExtractor, +) +from fairseq2.models.vit.frontend import StandardViTFrontend as StandardViTFrontend diff --git a/src/fairseq2/models/vit/feature_extractor.py b/src/fairseq2/models/vit/feature_extractor.py new file mode 100644 index 000000000..50747eaaf --- /dev/null +++ b/src/fairseq2/models/vit/feature_extractor.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import final + +from torch import Tensor +from torch.nn import Conv2d, Conv3d, Module +from typing_extensions import override + +from fairseq2.typing import DataType, Device + + +class PatchFeatureExtractor(Module, ABC): + """ + Extracts patch features from N-dimensional inputs and embeds them in a + latent space. + """ + + feature_dim: int + + def __init__(self, feature_dim: int) -> None: + """ + :param feature_dim: + The dimensionality of extracted patch features. + """ + super().__init__() + + self.feature_dim = feature_dim + + @abstractmethod + def forward(self, x: Tensor) -> Tensor: + """ + :param x: The inputs from which to extract patch features. *Shape:* + :math:`(N,C,*)`, where :math:`N` is the batch size, :math:`C` is the + number of channels, and :math:`*` is any number of input-specific + dimensions. + + :returns: The extracted patch features. *Shape:* :math:`(N,*,E)`, where + :math:`N` is the batch size, :math:`*` is the same number of + dimensions as in input, but potentially with different + dimensionality, and :math:`E` is the dimensionality of the patch + features. + """ + + +@final +class Conv2dPatchFeatureExtractor(PatchFeatureExtractor): + """Extracts patch features from 2-dimensional inputs using convolution.""" + + conv: Conv2d + init_fn: Callable[[Conv2d], None] | None + + def __init__( + self, + num_channels: int, + feature_dim: int, + patch_dims: tuple[int, int], + *, + init_fn: Callable[[Conv2d], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param num_channels: The number of input channels. + :param feature_dim: The dimensionality of extracted patch features. + :param patch_dims: The dimensionality of height and width patch + dimensions. + """ + super().__init__(feature_dim) + + self.conv = Conv2d( + num_channels, + feature_dim, + kernel_size=patch_dims, + stride=patch_dims, + device=device, + dtype=dtype, + ) + + self.init_fn = init_fn + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + if self.init_fn is not None: + self.init_fn(self.conv) + else: + self.conv.reset_parameters() + + @override + def forward(self, x: Tensor) -> Tensor: + # (N, C, H_inp, W_inp) -> (N, H_out, W_out, E) + return self.conv(x).permute(0, 2, 3, 1) # type: ignore[no-any-return] + + +@final +class Conv3dPatchFeatureExtractor(PatchFeatureExtractor): + """Extracts patch features from 3-dimensional inputs using convolution.""" + + conv: Conv3d + init_fn: Callable[[Conv3d], None] | None + + def __init__( + self, + num_channels: int, + feature_dim: int, + patch_dims: tuple[int, int, int], + *, + init_fn: Callable[[Conv3d], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param num_channels: The number of input channels. + :param feature_dim: The dimensionality of extracted patch features. + :param patch_dims: The dimensionality of depth, height, and width patch + dimensions. + """ + super().__init__(feature_dim) + + self.conv = Conv3d( + num_channels, + feature_dim, + kernel_size=patch_dims, + stride=patch_dims, + device=device, + dtype=dtype, + ) + + self.init_fn = init_fn + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + if self.init_fn is not None: + self.init_fn(self.conv) + else: + self.conv.reset_parameters() + + @override + def forward(self, x: Tensor) -> Tensor: + # (N, C, D_inp, H_inp, W_inp) -> (N, D_out, H_out, W_out, E) + return self.conv(x).permute(0, 2, 3, 4, 1) # type: ignore[no-any-return] diff --git a/src/fairseq2/models/vit/frontend.py b/src/fairseq2/models/vit/frontend.py new file mode 100644 index 000000000..fb9185489 --- /dev/null +++ b/src/fairseq2/models/vit/frontend.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import final + +from torch import Tensor +from torch.nn import Dropout +from typing_extensions import override + +from fairseq2.models.transformer import TransformerFrontend +from fairseq2.models.vit.feature_extractor import PatchFeatureExtractor +from fairseq2.nn import InterpolatedPositionEncoder +from fairseq2.nn.incremental_state import IncrementalStateBag +from fairseq2.nn.padding import PaddingMask + + +@final +class StandardViTFrontend(TransformerFrontend): + """Represents a standard Vision Transformer front-end as described in + :cite:t:`https://doi.org/10.48550/arXiv.2010.11929`.""" + + feature_extractor: PatchFeatureExtractor + pos_encoder: InterpolatedPositionEncoder + dropout: Dropout | None + + def __init__( + self, + feature_extractor: PatchFeatureExtractor, + pos_encoder: InterpolatedPositionEncoder, + *, + dropout_p: float = 0.0, + ) -> None: + """ + :param feature_extractor: The feature extractor. + :param pos_encoder: The interpolated position encoder. + :param dropout_p: The dropout probability on extracted patch features. + """ + feature_dim = feature_extractor.feature_dim + + super().__init__(feature_dim) + + self.feature_extractor = feature_extractor + + if pos_encoder.encoding_dim != feature_dim: + raise ValueError( + f"`pos_encoder.encoding_dim` must be equal to `feature_extractor.feature_dim` ({feature_dim}), but is {pos_encoder.encoding_dim} instead." + ) + + self.pos_encoder = pos_encoder + + if dropout_p > 0.0: + self.dropout = Dropout(dropout_p) + else: + self.register_module("dropout", None) + + @override + def forward( + self, + seqs: Tensor, + padding_mask: PaddingMask | None, + *, + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: + if padding_mask is not None: + raise ValueError(f"`{type(self)}` does not support padding mask.") + + if state_bag is not None: + raise ValueError(f"`{type(self)}` does not support incremental decoding.") + + seqs = self.feature_extractor(seqs) + + seqs = self.pos_encoder(seqs) + + # (N, *, E) -> (N, S, E) + seqs = seqs.flatten(1, -2) + + return seqs, None diff --git a/src/fairseq2/models/w2vbert/factory.py b/src/fairseq2/models/w2vbert/factory.py index 52934de8e..df4f788d2 100644 --- a/src/fairseq2/models/w2vbert/factory.py +++ b/src/fairseq2/models/w2vbert/factory.py @@ -7,15 +7,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Final, Optional +from typing import Final from fairseq2.config_registry import ConfigRegistry -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.w2vbert.model import W2VBertModel from fairseq2.models.wav2vec2 import ( Wav2Vec2Builder, Wav2Vec2Config, - Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig, ) from fairseq2.nn.transformer import TransformerNormOrder @@ -24,7 +23,7 @@ W2VBERT_FAMILY: Final = "w2vbert" -@dataclass +@dataclass(kw_only=True) class W2VBertConfig: """Holds the configuration of a w2v-BERT model. @@ -102,16 +101,16 @@ class W2VBertBuilder: _config: W2VBertConfig _w2v2_builder: Wav2Vec2Builder - _device: Optional[Device] - _dtype: Optional[DataType] + _device: Device | None + _dtype: DataType | None def __init__( self, config: W2VBertConfig, - w2v2_builder: Wav2Vec2Builder, + w2v2_builder: Wav2Vec2Builder | None = None, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -140,6 +139,11 @@ def __init__( self._config = config + if w2v2_builder is None: + w2v2_builder = Wav2Vec2Builder( + config.w2v2_config, device=device, dtype=dtype + ) + self._w2v2_builder = w2v2_builder self._device, self._dtype = device, dtype @@ -148,7 +152,7 @@ def build_model(self) -> W2VBertModel: """Build a model.""" w2v2_model = self._w2v2_builder.build_model() - return W2VBertModel( + model = W2VBertModel( w2v2_model, self._config.num_bert_encoder_layers, num_target_codebooks=self._config.num_target_codebooks, @@ -156,38 +160,21 @@ def build_model(self) -> W2VBertModel: dtype=self._dtype, ) + model.set_family(W2VBERT_FAMILY) + + return model + def create_w2vbert_model( config: W2VBertConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> W2VBertModel: - """Create a w2v-BERT model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - encoder_builder = Wav2Vec2EncoderBuilder( - config.w2v2_config.encoder_config, device=device, dtype=dtype - ) - - w2v2_builder = Wav2Vec2Builder( - config.w2v2_config, encoder_builder, device=device, dtype=dtype - ) - - builder = W2VBertBuilder(config, w2v2_builder, device=device, dtype=dtype) - - return builder.build_model().set_family(W2VBERT_FAMILY) + """Create a w2v-BERT model.""" + return W2VBertBuilder(config, device=device, dtype=dtype).build_model() -create_model.register( - family=W2VBERT_FAMILY, - factory=create_w2vbert_model, - config_kls=W2VBertConfig, - arch_configs=w2vbert_archs, +model_factories.register( + W2VBERT_FAMILY, create_w2vbert_model, W2VBertConfig, w2vbert_archs ) diff --git a/src/fairseq2/models/w2vbert/loader.py b/src/fairseq2/models/w2vbert/loader.py index 0b00e710e..b98ca87bb 100644 --- a/src/fairseq2/models/w2vbert/loader.py +++ b/src/fairseq2/models/w2vbert/loader.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import torch @@ -21,13 +21,13 @@ ) load_w2vbert_config = StandardModelConfigLoader( - family=W2VBERT_FAMILY, config_kls=W2VBertConfig, arch_configs=w2vbert_archs + W2VBERT_FAMILY, W2VBertConfig, w2vbert_archs ) def convert_w2vbert_checkpoint( - checkpoint: Dict[str, Any], config: W2VBertConfig -) -> Dict[str, Any]: + checkpoint: dict[str, Any], config: W2VBertConfig +) -> dict[str, Any]: """Convert a fairseq w2v-BERT checkpoint to fairseq2 format.""" # Check if we have a fairseq2 checkpoint. try: @@ -35,6 +35,7 @@ def convert_w2vbert_checkpoint( except KeyError: return checkpoint + # Check if we have a fairseq2 checkpoint. if "mlm_proj.weight" not in state_dict: return checkpoint diff --git a/src/fairseq2/models/w2vbert/model.py b/src/fairseq2/models/w2vbert/model.py index 5eab3be6b..2ad4a7b64 100644 --- a/src/fairseq2/models/w2vbert/model.py +++ b/src/fairseq2/models/w2vbert/model.py @@ -7,19 +7,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, final +from typing import final from torch import Tensor from torch.nn.functional import cross_entropy from fairseq2.models.model import Model from fairseq2.models.sequence import SequenceBatch -from fairseq2.models.wav2vec2 import ( - Wav2Vec2Features, - Wav2Vec2Loss, - Wav2Vec2Model, - Wav2Vec2Output, -) +from fairseq2.models.wav2vec2 import Wav2Vec2Loss, Wav2Vec2Model, Wav2Vec2Output from fairseq2.models.wav2vec2.masker import extract_masked_elements from fairseq2.nn import Linear from fairseq2.nn.padding import PaddingMask @@ -42,8 +37,8 @@ def __init__( num_bert_encoder_layers: int, *, num_target_codebooks: int = 1, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param w2v2_model: @@ -78,40 +73,29 @@ def forward(self, batch: SequenceBatch) -> W2VBertOutput: :param batch: The batch of sequences to process. """ - seqs, padding_mask, targets, temporal_mask = self.w2v2_model.run_frontend( - batch.seqs, batch.padding_mask - ) - - w2v2_layer_output = None - w2v2_layer_padding_mask = None + w2v2_features = self.w2v2_model.run_frontend(batch.seqs, batch.padding_mask) def hook( layer_idx: int, layer_output: Tensor, - layer_padding_mask: Optional[PaddingMask], + layer_padding_mask: PaddingMask | None, num_layers: int, ) -> bool: - nonlocal w2v2_layer_output - nonlocal w2v2_layer_padding_mask + nonlocal w2v2_features if layer_idx == num_layers - self.num_bert_encoder_layers - 1: - w2v2_layer_output = layer_output - w2v2_layer_padding_mask = layer_padding_mask + w2v2_features.seqs = layer_output return True with self.w2v2_model.encoder.register_layer_output_hook(hook): - encoder_output, _ = self.w2v2_model.encoder(seqs, padding_mask) - - assert w2v2_layer_output is not None - - features = Wav2Vec2Features( - w2v2_layer_output, w2v2_layer_padding_mask, targets, temporal_mask - ) + encoder_output, _ = self.w2v2_model.encoder( + w2v2_features.seqs, w2v2_features.padding_mask + ) - w2v2_output = self.w2v2_model.quantize_and_contrast(features) + w2v2_output = self.w2v2_model.quantize_and_contrast(w2v2_features) - seqs = extract_masked_elements(encoder_output, temporal_mask) + seqs = extract_masked_elements(encoder_output, w2v2_features.temporal_mask) bert_logits = self.final_bert_proj(seqs) @@ -178,10 +162,12 @@ def compute_loss( w2v2_loss = self.w2v2_output.compute_loss() - l1 = bert_loss_weight * bert_loss - l2 = w2v2_loss_weight * w2v2_loss.total + weighted_bert_loss = bert_loss_weight * bert_loss + weighted_w2v2_loss = w2v2_loss_weight * w2v2_loss.total - return W2VBertLoss(l1 + l2, bert_loss, w2v2_loss) + return W2VBertLoss( + weighted_bert_loss + weighted_w2v2_loss, bert_loss, w2v2_loss + ) def compute_bert_loss(self, *, label_smoothing: float = 0.0) -> Tensor: """Compute the masked prediction loss. @@ -189,11 +175,11 @@ def compute_bert_loss(self, *, label_smoothing: float = 0.0) -> Tensor: :param label_smoothing: The amount of label smoothing when computing masked prediction loss. """ + # For numerical stability in low-precision. + logits = self.bert_logits.float() + return cross_entropy( - self.bert_logits, - self.bert_targets, - reduction="sum", - label_smoothing=label_smoothing, + logits, self.bert_targets, reduction="sum", label_smoothing=label_smoothing ) @@ -210,7 +196,3 @@ class W2VBertLoss: w2v2: Wav2Vec2Loss """The loss of the wav2vec 2.0 model.""" - - def detach(self) -> W2VBertLoss: - """Return a copy detached from the autograd graph.""" - return W2VBertLoss(self.total.detach(), self.bert.detach(), self.w2v2.detach()) diff --git a/src/fairseq2/models/wav2vec2/archs.py b/src/fairseq2/models/wav2vec2/archs.py index 65845fcbc..10b7c8de6 100644 --- a/src/fairseq2/models/wav2vec2/archs.py +++ b/src/fairseq2/models/wav2vec2/archs.py @@ -84,7 +84,7 @@ def _pseudo_dinosr_base() -> Wav2Vec2Config: ) return Wav2Vec2Config( - encoder_config, + encoder_config=encoder_config, final_dim=256, final_proj_bias=True, temporal_mask_span_len=10, diff --git a/src/fairseq2/models/wav2vec2/asr/factory.py b/src/fairseq2/models/wav2vec2/asr/factory.py index 5063ec7e9..c2f51a46c 100644 --- a/src/fairseq2/models/wav2vec2/asr/factory.py +++ b/src/fairseq2/models/wav2vec2/asr/factory.py @@ -7,23 +7,23 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Final, Optional +from typing import Final from fairseq2.config_registry import ConfigRegistry from fairseq2.data import VocabularyInfo -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.wav2vec2.asr.model import Wav2Vec2AsrModel from fairseq2.models.wav2vec2.factory import ( Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig, ) -from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker +from fairseq2.models.wav2vec2.masker import StandardWav2Vec2Masker, Wav2Vec2Masker from fairseq2.typing import DataType, Device WAV2VEC2_ASR_FAMILY: Final = "wav2vec2_asr" -@dataclass +@dataclass(kw_only=True) class Wav2Vec2AsrConfig: """Holds the configuration of a wav2vec 2.0 ASR model. @@ -58,7 +58,7 @@ class Wav2Vec2AsrConfig: temporal_mask_span_len: int = 10 """The length of each temporal mask span that is applied over time steps.""" - max_temporal_mask_prob: float = 0.70 + max_temporal_mask_prob: float = 0.69 """The maximum probability of masking a time step. Note that, due to mask span overlap, the effective probability will be lower.""" @@ -91,16 +91,16 @@ class Wav2Vec2AsrBuilder: _config: Wav2Vec2AsrConfig _encoder_builder: Wav2Vec2EncoderBuilder - _device: Optional[Device] - _dtype: Optional[DataType] + _device: Device | None + _dtype: DataType | None def __init__( self, config: Wav2Vec2AsrConfig, - encoder_builder: Wav2Vec2EncoderBuilder, + encoder_builder: Wav2Vec2EncoderBuilder | None = None, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -114,6 +114,11 @@ def __init__( """ self._config = config + if encoder_builder is None: + encoder_builder = Wav2Vec2EncoderBuilder( + config.encoder_config, device=device, dtype=dtype + ) + self._encoder_builder = encoder_builder self._device, self._dtype = device, dtype @@ -126,7 +131,7 @@ def build_model(self) -> Wav2Vec2AsrModel: masker = self.build_masker() - return Wav2Vec2AsrModel( + model = Wav2Vec2AsrModel( encoder_frontend, encoder, self._config.vocab_info, @@ -136,12 +141,16 @@ def build_model(self) -> Wav2Vec2AsrModel: dtype=self._dtype, ) - def build_masker(self) -> Optional[Wav2Vec2Masker]: + model.set_family(WAV2VEC2_ASR_FAMILY) + + return model + + def build_masker(self) -> Wav2Vec2Masker | None: """Build a feature masker.""" if not self._config.use_masking: return None - return Wav2Vec2Masker( + return StandardWav2Vec2Masker( self._config.encoder_config.model_dim, self._config.temporal_mask_span_len, self._config.max_temporal_mask_prob, @@ -157,30 +166,16 @@ def build_masker(self) -> Optional[Wav2Vec2Masker]: def create_wav2vec2_asr_model( config: Wav2Vec2AsrConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> Wav2Vec2AsrModel: - """Create a wav2vec 2.0 ASR model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - encoder_builder = Wav2Vec2EncoderBuilder( - config.encoder_config, device=device, dtype=dtype - ) - - builder = Wav2Vec2AsrBuilder(config, encoder_builder, device=device, dtype=dtype) - - return builder.build_model().set_family(WAV2VEC2_ASR_FAMILY) + """Create a wav2vec 2.0 ASR model.""" + return Wav2Vec2AsrBuilder(config, device=device, dtype=dtype).build_model() -create_model.register( - family=WAV2VEC2_ASR_FAMILY, - factory=create_wav2vec2_asr_model, - config_kls=Wav2Vec2AsrConfig, - arch_configs=wav2vec2_asr_archs, +model_factories.register( + WAV2VEC2_ASR_FAMILY, + create_wav2vec2_asr_model, + Wav2Vec2AsrConfig, + wav2vec2_asr_archs, ) diff --git a/src/fairseq2/models/wav2vec2/asr/loader.py b/src/fairseq2/models/wav2vec2/asr/loader.py index 482536e4a..99b04c8cd 100644 --- a/src/fairseq2/models/wav2vec2/asr/loader.py +++ b/src/fairseq2/models/wav2vec2/asr/loader.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any from fairseq2.models.config_loader import StandardModelConfigLoader from fairseq2.models.loader import StandardModelLoader, load_model @@ -20,21 +20,20 @@ from fairseq2.nn.transformer import TransformerNormOrder load_wav2vec2_asr_config = StandardModelConfigLoader( - family=WAV2VEC2_ASR_FAMILY, - config_kls=Wav2Vec2AsrConfig, - arch_configs=wav2vec2_asr_archs, + WAV2VEC2_ASR_FAMILY, Wav2Vec2AsrConfig, wav2vec2_asr_archs ) def convert_wav2vec2_asr_checkpoint( - checkpoint: Dict[str, Any], config: Wav2Vec2AsrConfig -) -> Dict[str, Any]: + checkpoint: dict[str, Any], config: Wav2Vec2AsrConfig +) -> dict[str, Any]: """Convert a fairseq wav2vec 2.0 ASR checkpoint to fairseq2 format.""" try: state_dict = checkpoint["model"] except KeyError: return checkpoint + # Check if we have a fairseq2 checkpoint. if "w2v_encoder.proj.weight" not in state_dict: return checkpoint diff --git a/src/fairseq2/models/wav2vec2/asr/model.py b/src/fairseq2/models/wav2vec2/asr/model.py index 150c977c9..b9df9e823 100644 --- a/src/fairseq2/models/wav2vec2/asr/model.py +++ b/src/fairseq2/models/wav2vec2/asr/model.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple, final +from typing import final import torch import torch.nn as nn @@ -33,8 +33,8 @@ class Wav2Vec2AsrModel(Model): model_dim: int encoder_frontend: Wav2Vec2Frontend - masker: Optional[Wav2Vec2Masker] - final_dropout: Optional[Dropout] + masker: Wav2Vec2Masker | None + final_dropout: Dropout | None final_proj: Linear target_vocab_info: VocabularyInfo @@ -44,10 +44,10 @@ def __init__( encoder: TransformerEncoder, target_vocab_info: VocabularyInfo, *, - masker: Optional[Wav2Vec2Masker] = None, + masker: Wav2Vec2Masker | None = None, final_dropout_p: float = 0.0, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param encoder_frontend: @@ -91,7 +91,7 @@ def forward(self, batch: SequenceBatch) -> Wav2Vec2AsrOutput: :param batch: The batch of sequences to process. """ - seqs, padding_mask = self.encoder_frontend.extract_features( + seqs, padding_mask, _ = self.encoder_frontend.extract_features( batch.seqs, batch.padding_mask ) @@ -125,13 +125,13 @@ class Wav2Vec2AsrOutput: where :math:`N` is the batch size, :math:`S_{out}` is the output sequence length, and :math:`T` is the size of the vocabulary.""" - padding_mask: Optional[PaddingMask] + padding_mask: PaddingMask | None """The padding mask of :attr:`logits`. *Shape:* :math:`(N,S_{out})`, where :math:`N` is the batch size and :math:`S_{out}` is the output sequence length.""" def compute_loss( - self, targets: Tensor, target_padding_mask: Optional[PaddingMask] + self, targets: Tensor, target_padding_mask: PaddingMask | None ) -> Tensor: """Compute the CTC (Connectionist Temporal Classification) loss. @@ -169,7 +169,7 @@ def compute_loss( def generate_hypotheses( self, pad_idx: int, blank_label: int = 0 - ) -> Tuple[Tensor, Optional[PaddingMask]]: + ) -> tuple[Tensor, PaddingMask | None]: """Generate hypotheses using greedy search. :param pad_idx: diff --git a/src/fairseq2/models/wav2vec2/factory.py b/src/fairseq2/models/wav2vec2/factory.py index a98684695..47fad5b32 100644 --- a/src/fairseq2/models/wav2vec2/factory.py +++ b/src/fairseq2/models/wav2vec2/factory.py @@ -7,20 +7,20 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Final, List, Optional, Tuple +from typing import Final from torch.nn import GELU, SiLU from fairseq2.config_registry import ConfigRegistry from fairseq2.models.conformer import ConformerBlock, ConformerConvolution -from fairseq2.models.factory import create_model +from fairseq2.models.factory import model_factories from fairseq2.models.feature_extractor import SequenceFeatureExtractor from fairseq2.models.wav2vec2.feature_extractor import ( Wav2Vec2FbankFeatureExtractor, Wav2Vec2FeatureExtractor, ) from fairseq2.models.wav2vec2.frontend import Wav2Vec2Frontend -from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker +from fairseq2.models.wav2vec2.masker import StandardWav2Vec2Masker, Wav2Vec2Masker from fairseq2.models.wav2vec2.model import Wav2Vec2Model from fairseq2.models.wav2vec2.position_encoder import ( Wav2Vec2PositionEncoder, @@ -31,6 +31,7 @@ VectorQuantizer, ) from fairseq2.nn import PositionEncoder, RotaryEncoder +from fairseq2.nn.projection import init_bert_projection from fairseq2.nn.transformer import ( SDPA, FeedForwardNetwork, @@ -51,7 +52,7 @@ WAV2VEC2_FAMILY: Final = "wav2vec2" -@dataclass +@dataclass(kw_only=True) class Wav2Vec2Config: """Holds the configuration of a wav2vec 2.0 model. @@ -71,11 +72,16 @@ class Wav2Vec2Config: final_proj_bias: bool = True """If ``True``, the final projection learns an additive bias.""" + quantizer_encoder_grad: bool = True + """If ``True``, gradients are propagated from the quantizer through the convolutional + encoder. Otherwise, they are detached and the encoder is only trained with gradients + from the transformer. """ + # Mask temporal_mask_span_len: int = 10 """The length of each temporal mask span that is applied over time steps.""" - max_temporal_mask_prob: float = 0.65 + max_temporal_mask_prob: float = 0.69 """The maximum probability of masking a time step. Note that, due to mask span overlap, the effective probability will be lower.""" @@ -102,7 +108,7 @@ class Wav2Vec2Config: num_codebook_entries: int = 320 """The number of entries per codebook.""" - codebook_sampling_temperature: Tuple[float, float, float] = (2.0, 0.5, 0.999995) + codebook_sampling_temperature: tuple[float, float, float] = (2.0, 0.5, 0.999995) """A tuple of start temperature, end temperature, and decay factor for codebook entry sampling.""" @@ -119,7 +125,7 @@ class Wav2Vec2Config: wav2vec2_arch = wav2vec2_archs.decorator -@dataclass +@dataclass(kw_only=True) class Wav2Vec2EncoderConfig: """Holds the configuration of a wav2vec 2.0 encoder. @@ -148,7 +154,7 @@ class Wav2Vec2EncoderConfig: """If ``True``, applies Layer Normalization to extracted features.""" # Waveform Feature Extractor - feature_extractor_layer_descs: List[Tuple[int, int, int]] = field( + feature_extractor_layer_descs: list[tuple[int, int, int]] = field( default_factory=lambda: [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 ) """A tuple of output dimension, kernel size, and stride for each feature @@ -237,16 +243,16 @@ class Wav2Vec2Builder: _config: Wav2Vec2Config _encoder_builder: Wav2Vec2EncoderBuilder - _device: Optional[Device] - _dtype: Optional[DataType] + _device: Device | None + _dtype: DataType | None def __init__( self, config: Wav2Vec2Config, - encoder_builder: Wav2Vec2EncoderBuilder, + encoder_builder: Wav2Vec2EncoderBuilder | None = None, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -260,6 +266,11 @@ def __init__( """ self._config = config + if encoder_builder is None: + encoder_builder = Wav2Vec2EncoderBuilder( + config.encoder_config, device=device, dtype=dtype + ) + self._encoder_builder = encoder_builder self._device, self._dtype = device, dtype @@ -274,7 +285,7 @@ def build_model(self) -> Wav2Vec2Model: quantizer = self.build_quantizer() - return Wav2Vec2Model( + model = Wav2Vec2Model( encoder_frontend, encoder, masker, @@ -283,13 +294,18 @@ def build_model(self) -> Wav2Vec2Model: final_proj_bias=self._config.final_proj_bias, num_distractors=self._config.num_distractors, logit_temp=self._config.logit_temp, + quantizer_encoder_grad=self._config.quantizer_encoder_grad, device=self._device, dtype=self._dtype, ) + model.set_family(WAV2VEC2_FAMILY) + + return model + def build_masker(self) -> Wav2Vec2Masker: """Build a feature masker.""" - return Wav2Vec2Masker( + return StandardWav2Vec2Masker( self._config.encoder_config.model_dim, self._config.temporal_mask_span_len, self._config.max_temporal_mask_prob, @@ -323,16 +339,16 @@ class Wav2Vec2EncoderBuilder: """ _config: Wav2Vec2EncoderConfig - _device: Optional[Device] - _dtype: Optional[DataType] - _rel_pos_encoding: Optional[RelativePositionalEncoding] + _device: Device | None + _dtype: DataType | None + _rel_pos_encoding: RelativePositionalEncoding | None def __init__( self, config: Wav2Vec2EncoderConfig, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param config: @@ -371,7 +387,7 @@ def build_frontend(self) -> Wav2Vec2Frontend: dtype=self._dtype, ) - def build_feature_extractor(self) -> Optional[SequenceFeatureExtractor]: + def build_feature_extractor(self) -> SequenceFeatureExtractor | None: """Build a feature extractor.""" if self._config.use_fbank: return Wav2Vec2FbankFeatureExtractor( @@ -389,7 +405,7 @@ def build_feature_extractor(self) -> Optional[SequenceFeatureExtractor]: dtype=self._dtype, ) - def build_position_encoder(self) -> Optional[PositionEncoder]: + def build_position_encoder(self) -> PositionEncoder | None: """Build a position encoder.""" if self._config.pos_encoder_type != "conv": return None @@ -480,8 +496,10 @@ def build_attention(self) -> MultiheadAttention: return StandardMultiheadAttention( self._config.model_dim, self._config.num_encoder_attn_heads, + qkv_proj_init_fn=init_bert_projection, pos_encoder=pos_encoder, sdpa=sdpa, + output_proj_init_fn=init_bert_projection, device=self._device, dtype=self._dtype, ) @@ -525,6 +543,7 @@ def build_ffn(self, use_swish: bool = False) -> FeedForwardNetwork: inner_activation=SiLU() if use_swish else GELU(), inner_dropout_p=self._config.ffn_inner_dropout_p, norm_order=self._config.norm_order, + proj_init_fn=init_bert_projection, device=self._device, dtype=self._dtype, ) @@ -533,30 +552,13 @@ def build_ffn(self, use_swish: bool = False) -> FeedForwardNetwork: def create_wav2vec2_model( config: Wav2Vec2Config, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> Wav2Vec2Model: - """Create a wav2vec 2.0 model. - - :param config: - The configuration. - :param device: - The device on which to initialize modules. - :param dtype: - The data type of module parameters and buffers. - """ - encoder_builder = Wav2Vec2EncoderBuilder( - config.encoder_config, device=device, dtype=dtype - ) - - builder = Wav2Vec2Builder(config, encoder_builder, device=device, dtype=dtype) - - return builder.build_model().set_family(WAV2VEC2_FAMILY) + """Create a wav2vec 2.0 model.""" + return Wav2Vec2Builder(config, device=device, dtype=dtype).build_model() -create_model.register( - family=WAV2VEC2_FAMILY, - factory=create_wav2vec2_model, - config_kls=Wav2Vec2Config, - arch_configs=wav2vec2_archs, +model_factories.register( + WAV2VEC2_FAMILY, create_wav2vec2_model, Wav2Vec2Config, wav2vec2_archs ) diff --git a/src/fairseq2/models/wav2vec2/feature_extractor.py b/src/fairseq2/models/wav2vec2/feature_extractor.py index f408ffe85..553b9acf2 100644 --- a/src/fairseq2/models/wav2vec2/feature_extractor.py +++ b/src/fairseq2/models/wav2vec2/feature_extractor.py @@ -6,19 +6,21 @@ from __future__ import annotations -from typing import Optional, Sequence, Tuple, final +from collections.abc import Sequence +from typing import final import torch import torch.nn as nn from torch import Tensor from torch.nn import GELU, Conv1d, Dropout, GroupNorm, Module, Sequential from torch.nn.functional import group_norm, layer_norm +from typing_extensions import override from fairseq2.models.feature_extractor import SequenceFeatureExtractor from fairseq2.nn import LayerNorm from fairseq2.nn.padding import PaddingMask from fairseq2.nn.utils.gradient import scale_gradient -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -28,21 +30,21 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): :cite:t:`https://doi.org/10.48550/arxiv.2006.11477`.""" layers: Sequential - layer_descs: Sequence[Tuple[int, int, int]] + layer_descs: Sequence[tuple[int, int, int]] num_channels: int gradient_scale: float def __init__( self, - layer_descs: Sequence[Tuple[int, int, int]], + layer_descs: Sequence[tuple[int, int, int]], bias: bool, *, num_channels: int = 1, dropout_p: float = 0.0, layer_norm: bool = False, gradient_scale: float = 1.0, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param layer_descs: @@ -62,14 +64,14 @@ def __init__( value less than 1.0 allows the feature extractor to learn at a lower rate than the rest of the model. """ + if len(layer_descs) == 0: + raise ValueError("`layer_descs` must be non-empty.") + # The output dimensionality of the last feature extraction layer. feature_dim = layer_descs[-1][0] super().__init__(feature_dim) - if len(layer_descs) == 0: - raise ValueError("`layer_descs` must be non-empty.") - self.layers = Sequential() if num_channels < 1: @@ -132,8 +134,8 @@ def __init__( @override def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: """See the base :meth:`SequenceFeatureExtractor.forward`. :param seqs: @@ -195,9 +197,9 @@ class Wav2Vec2FeatureExtractionLayer(Module): :class:`Wav2Vec2FeatureExtractor`.""" conv: Conv1d - dropout: Optional[Dropout] - group_norm: Optional[GroupNorm] - layer_norm: Optional[LayerNorm] + dropout: Dropout | None + group_norm: GroupNorm | None + layer_norm: LayerNorm | None activation: GELU def __init__( @@ -209,10 +211,10 @@ def __init__( bias: bool, *, dropout_p: float = 0.0, - group_norm: Optional[GroupNorm] = None, - layer_norm: Optional[LayerNorm] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + group_norm: GroupNorm | None = None, + layer_norm: LayerNorm | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: super().__init__() @@ -300,8 +302,8 @@ def __init__( @override def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: """See the base :meth:`SequenceFeatureExtractor.forward`. :param seqs: diff --git a/src/fairseq2/models/wav2vec2/frontend.py b/src/fairseq2/models/wav2vec2/frontend.py index efcc910bf..a241a01c6 100644 --- a/src/fairseq2/models/wav2vec2/frontend.py +++ b/src/fairseq2/models/wav2vec2/frontend.py @@ -6,18 +6,20 @@ from __future__ import annotations -from typing import Optional, Tuple, final +from typing import final from torch import Tensor from torch.nn import Dropout +from typing_extensions import override +from fairseq2.error import NotSupportedError from fairseq2.models.feature_extractor import SequenceFeatureExtractor from fairseq2.models.transformer import TransformerFrontend from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker from fairseq2.nn import LayerNorm, Linear, PositionEncoder, StandardLayerNorm from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -26,26 +28,26 @@ class Wav2Vec2Frontend(TransformerFrontend): :cite:t:`https://doi.org/10.48550/arxiv.2006.11477`.""" feature_dim: int - feature_extractor: Optional[SequenceFeatureExtractor] + feature_extractor: SequenceFeatureExtractor | None post_extract_layer_norm: LayerNorm - model_dim_proj: Optional[Linear] - first_pass_dropout: Optional[Dropout] - pos_encoder: Optional[PositionEncoder] - layer_norm: Optional[LayerNorm] - dropout: Optional[Dropout] + model_dim_proj: Linear | None + first_pass_dropout: Dropout | None + pos_encoder: PositionEncoder | None + layer_norm: LayerNorm | None + dropout: Dropout | None def __init__( self, model_dim: int, feature_dim: int, - feature_extractor: Optional[SequenceFeatureExtractor], - pos_encoder: Optional[PositionEncoder], + feature_extractor: SequenceFeatureExtractor | None, + pos_encoder: PositionEncoder | None, *, first_pass_dropout_p: float = 0.0, layer_norm: bool = False, dropout_p: float = 0.0, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -71,9 +73,9 @@ def __init__( self.feature_dim = feature_dim if feature_extractor is not None: - if feature_dim != feature_extractor.feature_dim: + if feature_extractor.feature_dim != feature_dim: raise ValueError( - f"`feature_dim` of `feature_extractor` must be equal to `feature_dim` ({feature_dim}), but is {feature_extractor.feature_dim} instead." + f"`feature_extractor.feature_dim` must be equal to `feature_dim` ({feature_dim}), but is {feature_extractor.feature_dim} instead." ) self.feature_extractor = feature_extractor @@ -99,7 +101,7 @@ def __init__( if pos_encoder is not None: if pos_encoder.encoding_dim != model_dim: raise ValueError( - f"`encoding_dim` of `pos_encoder` must be equal to `model_dim` ({model_dim}), but is {pos_encoder.encoding_dim} instead." + f"`pos_encoder.encoding_dim` must be equal to `model_dim` ({model_dim}), but is {pos_encoder.encoding_dim} instead." ) self.pos_encoder = pos_encoder @@ -122,24 +124,24 @@ def __init__( def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: if state_bag is not None: - raise ValueError( - "`Wav2Vec2Frontend` does not support incremental decoding." + raise NotSupportedError( + f"`{type(self)}` does not support incremental decoding." ) - seqs, padding_mask = self.extract_features(seqs, padding_mask) + seqs, padding_mask, _ = self.extract_features(seqs, padding_mask) seqs, padding_mask, _ = self.process_features(seqs, padding_mask) return seqs, padding_mask def extract_features( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None, Tensor]: """Extract features from the specified sequences. :param seqs: @@ -152,32 +154,36 @@ def extract_features( is the batch size and :math:`S` is the sequence length. :returns: - - The extracted features. *Shape:* :math:`(N,S_{out},F)`, where + - The normalized features. *Shape:* :math:`(N,S_{out},E)`, where :math:`N` is the batch size, :math:`S_{out}` is the output - sequence length, and :math:`F` is the dimensionality of the + sequence length, and :math:`E` is the dimensionality of the extracted features. - The padding mask of the extracted features. *Shape:* :math:`(N,S_{out})`, where :math:`N` is the batch size and :math:`S_{out}` is the output sequence length. + - The raw features. *Shape*: Same as the normalized features (i.e. + first element of the returned tuple). """ if self.feature_extractor is not None: seqs, padding_mask = self.feature_extractor(seqs, padding_mask) + raw_features = seqs.clone() + seqs = self.post_extract_layer_norm(seqs) - return seqs, padding_mask + return seqs, padding_mask, raw_features def process_features( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - masker: Optional[Wav2Vec2Masker] = None, - ) -> Tuple[Tensor, Optional[PaddingMask], Tensor]: + padding_mask: PaddingMask | None, + masker: Wav2Vec2Masker | None = None, + ) -> tuple[Tensor, PaddingMask | None, Tensor | None]: """Process extracted features. :param seqs: - The features to process. *Shape:* :math:`(N,S,F)`, where :math:`N` - is the batch size, :math:`S` is the sequence length, and :math:`F` + The features to process. *Shape:* :math:`(N,S,E)`, where :math:`N` + is the batch size, :math:`S` is the sequence length, and :math:`E` is the dimensionality of the features. :param padding_mask: The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N` diff --git a/src/fairseq2/models/wav2vec2/loader.py b/src/fairseq2/models/wav2vec2/loader.py index b5fb9b19b..a793ba551 100644 --- a/src/fairseq2/models/wav2vec2/loader.py +++ b/src/fairseq2/models/wav2vec2/loader.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import torch @@ -22,19 +22,20 @@ from fairseq2.nn.transformer import TransformerNormOrder load_wav2vec2_config = StandardModelConfigLoader( - family=WAV2VEC2_FAMILY, config_kls=Wav2Vec2Config, arch_configs=wav2vec2_archs + WAV2VEC2_FAMILY, Wav2Vec2Config, wav2vec2_archs ) def convert_wav2vec2_checkpoint( - checkpoint: Dict[str, Any], config: Wav2Vec2Config -) -> Dict[str, Any]: + checkpoint: dict[str, Any], config: Wav2Vec2Config +) -> dict[str, Any]: """Convert a fairseq wav2vec 2.0 checkpoint to fairseq2 format.""" try: state_dict = checkpoint["model"] except KeyError: return checkpoint + # Check if we have a fairseq2 checkpoint. if "project_q.weight" not in state_dict: return checkpoint diff --git a/src/fairseq2/models/wav2vec2/masker.py b/src/fairseq2/models/wav2vec2/masker.py index ccadda1a5..a45deae54 100644 --- a/src/fairseq2/models/wav2vec2/masker.py +++ b/src/fairseq2/models/wav2vec2/masker.py @@ -6,23 +6,51 @@ from __future__ import annotations -from typing import Optional, Tuple, final +from abc import ABC, abstractmethod +from typing import final import torch import torch.nn as nn from torch import Tensor from torch.nn import Module, Parameter +from typing_extensions import override from fairseq2.nn.padding import PaddingMask -from fairseq2.nn.utils.mask import compute_row_mask +from fairseq2.nn.utils.mask import RowMaskFactory, compute_row_mask from fairseq2.typing import DataType, Device +class Wav2Vec2Masker(Module, ABC): + """Masks extracted wav2vec 2.0 features.""" + + @abstractmethod + def forward( + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, Tensor]: + """ + :param seqs: + The sequences to mask. *Shape:* :math:`(N,S,M)`, where :math:`N` is + the batch size, :math:`S` is the sequence length, and :math:`M` is + the dimensionality of the model. + :param seq_lens: + An array where each element represents the length of the sequence at + the same index in ``seqs``. *Shape:* :math:`(N)`, where :math:`N` is + the batch size. + + :returns: + - The input sequences with mask applied. *Shape:* Same as ``seqs``. + - The temporal mask that has been applied to ``seqs``. *Shape:* + :math:`(N,S)`, where :math:`N` is the batch size and :math`S` is + the sequence length. + """ + + @final -class Wav2Vec2Masker(Module): - """Masks extracted features as described in Section 3.1 of +class StandardWav2Vec2Masker(Wav2Vec2Masker): + """Masks extracted wav2vec 2.0 features as described in Section 3.1 of :cite:t:`https://doi.org/10.48550/arxiv.2006.11477`.""" + mask_factory: RowMaskFactory temporal_span_len: int max_temporal_mask_prob: float temporal_mask_embed: Parameter @@ -39,8 +67,9 @@ def __init__( max_spatial_mask_prob: float = 0.0, min_num_spatial_mask_spans: int = 2, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + mask_factory: RowMaskFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -56,9 +85,14 @@ def __init__( :param max_spatial_mask_prob: The maximum probability of masking a feature. Note that, due to mask span overlap, the effective probability will be lower. + :param mask_factory: + The row mask factory. If ``None``, :func:`compute_row_mask` will be + used. """ super().__init__() + self.mask_factory = mask_factory or compute_row_mask + if max_temporal_mask_prob == 0.0: raise ValueError("`max_temporal_mask_prob` must be greater than 0.") @@ -80,29 +114,14 @@ def reset_parameters(self) -> None: """Reset the parameters and buffers of the module.""" nn.init.uniform_(self.temporal_mask_embed) + @override def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Tensor]: - """ - :param seqs: - The sequences to mask. *Shape:* :math:`(N,S,M)`, where :math:`N` is - the batch size, :math:`S` is the sequence length, and :math:`M` is - the dimensionality of the model. - :param seq_lens: - An array where each element represents the length of the sequence at - the same index in ``seqs``. *Shape:* :math:`(N)`, where :math:`N` is - the batch size. - - :returns: - - The input sequences with mask applied. *Shape:* Same as ``seqs``. - - The temporal mask that has been applied to ``seqs``. *Shape:* - :math:`(N,S)`, where :math:`N` is the batch size and :math`S` is - the sequence length. - """ + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, Tensor]: batch_size, seq_len, model_dim = seqs.shape # Temporal mask over time steps. - temporal_mask = compute_row_mask( + temporal_mask = self.mask_factory( shape=(batch_size, seq_len), span_len=self.temporal_span_len, max_mask_prob=self.max_temporal_mask_prob, @@ -118,7 +137,7 @@ def forward( if self.max_spatial_mask_prob > 0.0: # Spatial mask over features. # (N, M) - spatial_mask = compute_row_mask( + spatial_mask = self.mask_factory( shape=(batch_size, model_dim), span_len=self.spatial_span_len, max_mask_prob=self.max_spatial_mask_prob, @@ -137,7 +156,7 @@ def forward( def extra_repr(self) -> str: """:meta private:""" - return ( + s = ( f"temporal_span_len={self.temporal_span_len}, " f"max_temporal_mask_prob={self.max_temporal_mask_prob}, " f"min_num_temporal_mask_spans={self.min_num_temporal_mask_spans}, " @@ -146,6 +165,13 @@ def extra_repr(self) -> str: f"min_num_spatial_mask_spans={self.min_num_spatial_mask_spans}" ) + if self.mask_factory is not compute_row_mask: + mask_factory = getattr(self.mask_factory, "__name__", self.mask_factory) + + s = f"{s}, mask_factory={mask_factory}" + + return s + def extract_masked_elements(seqs: Tensor, temporal_mask: Tensor) -> Tensor: """Extract masked elements from ``seqs``. diff --git a/src/fairseq2/models/wav2vec2/model.py b/src/fairseq2/models/wav2vec2/model.py index 04fd39d53..9f457c4c9 100644 --- a/src/fairseq2/models/wav2vec2/model.py +++ b/src/fairseq2/models/wav2vec2/model.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple, final +from typing import final import torch from torch import Tensor @@ -54,8 +54,9 @@ def __init__( final_proj_bias: bool = True, num_distractors: int = 100, logit_temp: float = 0.1, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + quantizer_encoder_grad: bool = True, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param encoder_frontend: @@ -108,6 +109,7 @@ def __init__( self.num_distractors = num_distractors self.logit_temp = logit_temp + self.quantizer_encoder_grad = quantizer_encoder_grad def forward(self, batch: SequenceBatch) -> Wav2Vec2Output: """ @@ -124,19 +126,17 @@ def extract_features(self, batch: SequenceBatch) -> Wav2Vec2Features: :param batch: The batch of sequences to process. """ - seqs, padding_mask, targets, temporal_mask = self.run_frontend( - batch.seqs, batch.padding_mask - ) - - encoder_output, encoder_padding_mask = self.encoder(seqs, padding_mask) + features = self.run_frontend(batch.seqs, batch.padding_mask) - return Wav2Vec2Features( - encoder_output, encoder_padding_mask, targets, temporal_mask + features.seqs, features.padding_mask = self.encoder( + features.seqs, features.padding_mask ) + return features + def run_frontend( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask], Tensor, Tensor]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> Wav2Vec2Features: """Run the encoder frontend in pretraining mode. :param seqs: @@ -146,30 +146,17 @@ def run_frontend( :param padding_mask: The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N` is the batch size and :math:`S` is the sequence length. - - :returns: - - The processed features to pass to the context network. - *Shape:* :math:`(N,S_{out},M)`, where :math:`N` is the batch size, - :math:`S_{out}` is the output sequence length, and :math:`M` is - the dimensionality of the model. - - The padding mask of the processed features. *Shape:* - :math:`(N,S_{out})`, where :math:`N` is the batch size and - :math:`S_{out}` is the output sequence length. - - The non-quantized context network targets that have been extracted - from the input sequences. *Shape:* :math:`(N,S_{msk},M)`, where - :math:`N` is the batch size, :math:`S_{msk}` is the masked - sequence length, and :math:`M` is the dimensionality of the model. - - The temporal mask that has been applied to extract the context - network targets. *Shape:* :math:`(N,S_{out})`, where :math:`N` is - the batch size and :math`S_{out}` is the output sequence length. """ frontend = self.encoder_frontend - seqs, padding_mask = frontend.extract_features(seqs, padding_mask) + seqs, padding_mask, raw_features = frontend.extract_features(seqs, padding_mask) # We use the extracted features as context network targets after masking # and quantization. - targets = seqs.detach().clone() + if self.quantizer_encoder_grad: + targets = seqs.clone() + else: + targets = seqs.detach().clone() if frontend.first_pass_dropout is not None: targets = frontend.first_pass_dropout(targets) @@ -182,7 +169,9 @@ def run_frontend( targets = extract_masked_elements(targets, temporal_mask) - return seqs, padding_mask, targets, temporal_mask + return Wav2Vec2Features( + seqs, padding_mask, targets, temporal_mask, raw_features + ) def quantize_and_contrast(self, features: Wav2Vec2Features) -> Wav2Vec2Output: """Quantize targets and produce logits for contrastive prediction. @@ -191,8 +180,8 @@ def quantize_and_contrast(self, features: Wav2Vec2Features) -> Wav2Vec2Output: The extracted features from the encoder. """ encoder_output, encoder_padding_mask, targets, temporal_mask = ( - features.encoder_output, - features.encoder_padding_mask, + features.seqs, + features.padding_mask, features.targets, features.temporal_mask, ) @@ -216,6 +205,7 @@ def quantize_and_contrast(self, features: Wav2Vec2Features) -> Wav2Vec2Output: quantizer_output, encoder_output, encoder_padding_mask, + features.raw, ) def _sample_distractors(self, targets: Tensor) -> Tensor: @@ -286,7 +276,7 @@ def _compute_logits( if distractor_is_target.any(): logits[:, :, 1:][distractor_is_target] = -torch.inf - return logits.type_as(seqs) + return logits def extra_repr(self) -> str: """:meta private:""" @@ -302,15 +292,15 @@ def extra_repr(self) -> str: class Wav2Vec2Features: """Holds the extracted features of a wav2vec 2.0 model.""" - encoder_output: Tensor - """The context network output. *Shape:* :math:`(N,S_{enc},M)`, where - :math:`N` is the batch size, :math:`S_{enc}` is the encoder output sequence - length, and :math:`M` is the dimensionality of the model.""" + seqs: Tensor + """The features. *Shape:* :math:`(N,S_{enc},M)`, where :math:`N` is the + batch size, :math:`S_{out}` is the output sequence length, and :math:`M` is + the dimensionality of the model.""" - encoder_padding_mask: Optional[PaddingMask] - """The padding mask of :attr:`encoder_output`. *Shape:* :math:`(N,S_{enc})`, - where :math:`N` is the batch size and :math:`S_{enc}` is the encoder output - sequence length.""" + padding_mask: PaddingMask | None + """The padding mask of :attr:`seqs`. *Shape:* :math:`(N,S_{out})`, where + :math:`N` is the batch size and :math:`S_{out}` is the output sequence + length.""" targets: Tensor """The non-quantized context network targets that have been extracted from @@ -323,6 +313,9 @@ class Wav2Vec2Features: targets. *Shape:* :math:`(N,S_{enc})`, where :math:`N` is the batch size and :math`S_{enc}` is the encoder output sequence length.""" + raw: Tensor + """The raw features returned by the frontend. *Shape*: Same as :attr:`seqs`.""" + @final @dataclass @@ -354,24 +347,38 @@ class Wav2Vec2Output: :math:`N` is the batch size, :math:`S_{enc}` is the encoder output sequence length, and :math:`M` is the dimensionality of the model.""" - encoder_padding_mask: Optional[PaddingMask] + encoder_padding_mask: PaddingMask | None """The padding mask of :attr:`encoder_output`. *Shape:* :math:`(N,S_{enc})`, where :math:`N` is the batch size and :math:`S_{enc}` is the encoder output sequence length.""" - def compute_loss(self, *, diversity_loss_weight: float = 0.1) -> Wav2Vec2Loss: + raw_features: Tensor + """The raw features returned by the frontend. *Shape*: Same as + :attr:`encoder_output`.""" + + def compute_loss( + self, diversity_loss_weight: float = 0.1, feature_penalty_weight: float = 10.0 + ) -> Wav2Vec2Loss: """Compute the loss. :param diversity_loss_weight: The weight of diversity in loss computation. + :param feature_penalty_weight: + The weight of the feature penalty in loss computation. """ contrastive_loss = self.compute_contrastive_loss() diversity_loss = self.compute_diversity_loss() - total_loss = contrastive_loss + diversity_loss_weight * diversity_loss + feature_penalty = self.compute_feature_penalty() + + weighted_diversity_loss = diversity_loss_weight * diversity_loss + + weighted_feature_penalty = feature_penalty_weight * feature_penalty - return Wav2Vec2Loss(total_loss, contrastive_loss, diversity_loss) + loss = contrastive_loss + weighted_diversity_loss + weighted_feature_penalty + + return Wav2Vec2Loss(loss, contrastive_loss, diversity_loss, feature_penalty) def compute_contrastive_loss(self) -> Tensor: """Compute the contrastive loss.""" @@ -380,6 +387,9 @@ def compute_contrastive_loss(self) -> Tensor: # (N, S, L) -> (S x N, L) logits = self.logits.transpose(0, 1).reshape(-1, num_logits) + # For numerical stability in low-precision. + logits = logits.float() + # The target is always at index 0 in the candidate list. target_indices = logits.new_zeros((batch_size * seq_len,), dtype=torch.int64) @@ -391,6 +401,12 @@ def compute_diversity_loss(self) -> Tensor: return self.quantizer_output.compute_loss() * batch_size * seq_len + def compute_feature_penalty(self) -> Tensor: + """Compute the feature penalty.""" + batch_size, seq_len = self.logits.shape[:2] + + return self.raw_features.float().pow(2).mean() * batch_size * seq_len + @final @dataclass @@ -406,8 +422,5 @@ class Wav2Vec2Loss: diversity: Tensor """The diversity loss. *Shape:* :math:`()`.""" - def detach(self) -> Wav2Vec2Loss: - """Return a copy detached from the autograd graph.""" - return Wav2Vec2Loss( - self.total.detach(), self.contrastive.detach(), self.diversity.detach() - ) + feature_penalty: Tensor + """The feature penalty. *Shape:* :math:`()`.""" diff --git a/src/fairseq2/models/wav2vec2/position_encoder.py b/src/fairseq2/models/wav2vec2/position_encoder.py index 4886a7d50..5ec379352 100644 --- a/src/fairseq2/models/wav2vec2/position_encoder.py +++ b/src/fairseq2/models/wav2vec2/position_encoder.py @@ -7,20 +7,19 @@ from __future__ import annotations import warnings -from typing import Optional, final +from typing import final from warnings import catch_warnings -import torch import torch.nn as nn from torch import Tensor from torch.nn import GELU, Conv1d, Module, Sequential from torch.nn.utils import remove_weight_norm, weight_norm # type: ignore[attr-defined] +from typing_extensions import override from fairseq2.nn import LayerNorm, PositionEncoder, StandardLayerNorm from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask, apply_padding_mask -from fairseq2.typing import DataType, Device, override -from fairseq2.utils.version import torch_greater_or_equal +from fairseq2.typing import DataType, Device @final @@ -38,8 +37,8 @@ def __init__( kernel_size: int, num_groups: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -69,8 +68,8 @@ def __init__( def _do_forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: """:meta private:""" if state_bag is not None: @@ -116,11 +115,6 @@ def reset_parameters(self) -> None: except AttributeError: weight = self.weight - if weight.dtype == torch.bfloat16 and not torch_greater_or_equal(2, 2): - raise RuntimeError( - "`torch.nn.utils.weight_norm()` supports `torch.bfloat16` only in PyTorch 2.2 and later versions." - ) - nn.init.normal_( self.weight, mean=0.0, std=(4.0 / (kernel_size * model_dim)) ** 0.5 ) @@ -157,8 +151,8 @@ def __init__( num_groups: int, num_layers: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -192,8 +186,8 @@ def __init__( def _do_forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: """:meta private:""" if state_bag is not None: @@ -231,8 +225,8 @@ def __init__( kernel_size: int, num_groups: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: super().__init__() diff --git a/src/fairseq2/models/wav2vec2/vector_quantizer.py b/src/fairseq2/models/wav2vec2/vector_quantizer.py index dfb1e2b46..f61820b92 100644 --- a/src/fairseq2/models/wav2vec2/vector_quantizer.py +++ b/src/fairseq2/models/wav2vec2/vector_quantizer.py @@ -8,16 +8,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Tuple, final import torch import torch.nn as nn from torch import Tensor from torch.nn import Module, Parameter from torch.nn.functional import gumbel_softmax +from typing_extensions import override from fairseq2.nn import Linear -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device class VectorQuantizer(Module, ABC): @@ -61,7 +61,6 @@ def get_target_indices(self, num_codebooks: int) -> Tensor: pass -@final class GumbelVectorQuantizer(VectorQuantizer): """Quantizes incoming data using Gumbel-Softmax.""" @@ -83,9 +82,9 @@ def __init__( num_codebooks: int, num_codebook_entries: int, *, - codebook_sampling_temperature: Tuple[float, float, float], - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + codebook_sampling_temperature: tuple[float, float, float], + device: Device | None = None, + dtype: DataType | None = None, ): """ :param input_dim: @@ -143,7 +142,7 @@ def reset_parameters(self) -> None: self.num_updates.zero_() @override - def forward(self, x: Tensor) -> "GumbelVectorQuantizerOutput": + def forward(self, x: Tensor) -> GumbelVectorQuantizerOutput: current_temp = self._compute_current_temp() bsz, tsz, fsz = x.shape @@ -188,10 +187,17 @@ def forward(self, x: Tensor) -> "GumbelVectorQuantizerOutput": cb = x - x = x.unsqueeze(-1) * self.entries - x = x.view(bsz * tsz, self.num_codebooks, self.num_codebook_entries, -1) - x = x.sum(-2) - x = x.view(bsz, tsz, -1) + @torch.compile(fullgraph=True) + def compute_sum(x: torch.Tensor) -> torch.Tensor: + return torch.sum( + x.view(bsz * tsz, self.num_codebooks, self.num_codebook_entries, 1) + * self.entries.view( + 1, self.num_codebooks, self.num_codebook_entries, -1 + ), + dim=-2, + ) + + x = compute_sum(x).view(bsz, tsz, -1) return GumbelVectorQuantizerOutput( x, @@ -220,7 +226,6 @@ def init_entry_projection(proj: Linear) -> None: nn.init.zeros_(proj.bias) -@final @dataclass class GumbelVectorQuantizerOutput(VectorQuantizerOutput): cb: Tensor diff --git a/src/fairseq2/nn/__init__.py b/src/fairseq2/nn/__init__.py index b94868c59..c637eb2b3 100644 --- a/src/fairseq2/nn/__init__.py +++ b/src/fairseq2/nn/__init__.py @@ -16,11 +16,23 @@ from fairseq2.nn.normalization import LayerNorm as LayerNorm from fairseq2.nn.normalization import RMSNorm as RMSNorm from fairseq2.nn.normalization import StandardLayerNorm as StandardLayerNorm +from fairseq2.nn.position_encoder import ( + InterpolatedPositionEncoder as InterpolatedPositionEncoder, +) from fairseq2.nn.position_encoder import ( LearnedPositionEncoder as LearnedPositionEncoder, ) from fairseq2.nn.position_encoder import PositionEncoder as PositionEncoder from fairseq2.nn.position_encoder import RotaryEncoder as RotaryEncoder +from fairseq2.nn.position_encoder import ( + Sinusoidal2dPositionEncoder as Sinusoidal2dPositionEncoder, +) +from fairseq2.nn.position_encoder import ( + Sinusoidal3dPositionEncoder as Sinusoidal3dPositionEncoder, +) +from fairseq2.nn.position_encoder import ( + SinusoidalNdPositionEncoder as SinusoidalNdPositionEncoder, +) from fairseq2.nn.position_encoder import ( SinusoidalPositionEncoder as SinusoidalPositionEncoder, ) diff --git a/src/fairseq2/nn/ddp.py b/src/fairseq2/nn/ddp.py index 456e3a86e..57e285f11 100644 --- a/src/fairseq2/nn/ddp.py +++ b/src/fairseq2/nn/ddp.py @@ -6,8 +6,6 @@ from __future__ import annotations -from typing import List - import torch.distributed as dist from torch import Tensor from torch.distributed import GradBucket @@ -68,7 +66,7 @@ def _allreduce_hook(gang: Gang, bucket: GradBucket) -> Future[Tensor]: ft = dist.all_reduce(bucket.buffer(), group=pg, async_op=True).get_future() - def return_reduced_bucket(f: Future[List[Tensor]]) -> Tensor: + def return_reduced_bucket(f: Future[list[Tensor]]) -> Tensor: output = f.value() # Skip division by the world size. diff --git a/src/fairseq2/nn/embedding.py b/src/fairseq2/nn/embedding.py index d5769539e..9b19ee092 100644 --- a/src/fairseq2/nn/embedding.py +++ b/src/fairseq2/nn/embedding.py @@ -7,7 +7,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Callable, Optional, final +from collections.abc import Callable +from typing import final import torch import torch.nn as nn @@ -15,11 +16,12 @@ from torch.nn import Module from torch.nn.functional import embedding from torch.nn.parameter import Parameter +from typing_extensions import override from fairseq2.gang import Gang from fairseq2.nn.utils.module import to_empty from fairseq2.tensor_parallel import gather, reduce, reduce_on_backward -from fairseq2.typing import META, DataType, Device, override +from fairseq2.typing import META, DataType, Device class Embedding(Module, ABC): @@ -27,11 +29,11 @@ class Embedding(Module, ABC): num_embeddings: int embedding_dim: int - pad_idx: Optional[int] - padding_idx: Optional[int] # Compat + pad_idx: int | None + padding_idx: int | None # Compat def __init__( - self, num_embeddings: int, embedding_dim: int, pad_idx: Optional[int] = None + self, num_embeddings: int, embedding_dim: int, pad_idx: int | None = None ) -> None: """ :param num_embeddings: @@ -79,17 +81,17 @@ class StandardEmbedding(Embedding): """Stores embeddings of a fixed dictionary and size in an in-memory table.""" weight: Parameter - init_fn: Optional[Callable[[StandardEmbedding], None]] + init_fn: Callable[[StandardEmbedding], None] | None def __init__( self, num_embeddings: int, embedding_dim: int, - pad_idx: Optional[int] = None, + pad_idx: int | None = None, *, - init_fn: Optional[Callable[[StandardEmbedding], None]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[StandardEmbedding], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param num_embeddings: @@ -150,7 +152,7 @@ class VocabShardedEmbedding(Embedding): gang: Gang sharded_num_embeddings: int weight: Parameter - init_fn: Optional[Callable[[StandardEmbedding], None]] + init_fn: Callable[[StandardEmbedding], None] | None @staticmethod def from_embedding(embed: StandardEmbedding, gang: Gang) -> VocabShardedEmbedding: @@ -190,11 +192,11 @@ def __init__( gang: Gang, num_embeddings: int, embedding_dim: int, - pad_idx: Optional[int] = None, + pad_idx: int | None = None, *, - init_fn: Optional[Callable[[StandardEmbedding], None]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[StandardEmbedding], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param gang: @@ -214,7 +216,7 @@ def __init__( if num_embeddings % gang.size != 0: raise ValueError( - f"`num_embeddings` must be divisible by `gang.size` ({gang.size}), but is {num_embeddings} instead." + f"`num_embeddings` must be a multiple of `gang.size` ({gang.size}), but is {num_embeddings} instead." ) self.gang = gang @@ -285,7 +287,7 @@ def forward(self, x: Tensor) -> Tensor: return x - def to_embedding(self, device: Optional[Device] = None) -> StandardEmbedding: + def to_embedding(self, device: Device | None = None) -> StandardEmbedding: """Convert this instance to a :class:`StandardEmbedding`.""" embed = self._embedding_like(META) @@ -330,7 +332,7 @@ class ShardedEmbedding(Embedding): gang: Gang sharded_embedding_dim: int weight: Parameter - init_fn: Optional[Callable[[StandardEmbedding], None]] + init_fn: Callable[[StandardEmbedding], None] | None @staticmethod def from_embedding(embed: StandardEmbedding, gang: Gang) -> ShardedEmbedding: @@ -370,11 +372,11 @@ def __init__( gang: Gang, num_embeddings: int, embedding_dim: int, - pad_idx: Optional[int] = None, + pad_idx: int | None = None, *, - init_fn: Optional[Callable[[StandardEmbedding], None]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[StandardEmbedding], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param gang: @@ -394,7 +396,7 @@ def __init__( if embedding_dim % gang.size != 0: raise ValueError( - f"`embedding_dim` must be divisible by `gang.size` ({gang.size}), but is {embedding_dim} instead." + f"`embedding_dim` must be a multiple of `gang.size` ({gang.size}), but is {embedding_dim} instead." ) self.gang = gang @@ -440,7 +442,7 @@ def forward(self, x: Tensor) -> Tensor: return x - def to_embedding(self, device: Optional[Device] = None) -> StandardEmbedding: + def to_embedding(self, device: Device | None = None) -> StandardEmbedding: """Convert this instance to a :class:`StandardEmbedding`.""" embed = self._embedding_like(META) diff --git a/src/fairseq2/nn/fsdp.py b/src/fairseq2/nn/fsdp.py index 706911425..0f143a406 100644 --- a/src/fairseq2/nn/fsdp.py +++ b/src/fairseq2/nn/fsdp.py @@ -6,50 +6,52 @@ from __future__ import annotations +import warnings +from collections.abc import Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Final, Iterator, Optional, Protocol, Sequence, Set, final +from typing import Any, Final, Protocol, final +from warnings import catch_warnings import torch from torch import Tensor +from torch.distributed import ProcessGroup from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision from torch.distributed.fsdp.api import ( BackwardPrefetch, CPUOffload, + MixedPrecision, ShardedOptimStateDictConfig, ShardedStateDictConfig, ShardingStrategy, StateDictType, ) -from torch.nn import Module +from torch.nn import Module, Parameter -from fairseq2.gang import Gang -from fairseq2.logging import get_log_writer +from fairseq2.gang import Gang, setup_hybrid_fsdp_gangs +from fairseq2.logging import log from fairseq2.nn.utils.module import ( + apply_to_parameters, infer_device, reset_non_persistent_buffers, reset_parameters, to_empty, ) from fairseq2.typing import DataType, Device -from fairseq2.utils.version import torch_greater_or_equal - -log = get_log_writer(__name__) def to_fsdp( module: Module, gang: Gang, - wrap_policy: Optional[FSDPWrapPolicy], + wrap_policy: FSDPWrapPolicy | None, *, - ignored_modules: Optional[Sequence[Module]] = None, + ignored_modules: Sequence[Module] | None = None, skip_init: bool = False, broadcast_state: bool = False, - memory_policy: Optional[FSDPMemoryPolicy] = None, + memory_policy: FSDPMemoryPolicy | None = None, reshard_after_forward: bool = True, - local_world_size: Optional[int] = None, - mixed_precision_dtype: Optional[DataType] = None, + local_world_size: int | None = None, + mixed_precision_dtype: DataType | None = None, fp32_reduce: bool = False, ) -> FSDP: """Wrap ``module`` with FSDP. @@ -64,9 +66,7 @@ def to_fsdp( :param ignored_param_names: The ignored parameter names. Can contain regular expressions. :param skip_init: - If ``True``, skips initializing the parameters and buffers moved from - the meta device onto the device of ``gang``. Only relevant if ``module`` - resides on the meta device. + Not used. :param broadcast_state: If ``True``, each FSDP module will broadcast its parameters and buffers from rank 0 to ensure that they are replicated across all processes. @@ -88,46 +88,49 @@ def to_fsdp( If ``True``, the gradients will be reduced in full precision. Only relevant if ``mixed_precision_dtype`` is not ``None``. """ + process_group: ProcessGroup | tuple[ProcessGroup, ProcessGroup] | None = None + if local_world_size is not None: - if local_world_size == 0: - raise ValueError( - f"`local_world_size` must be greater than 0, but is {local_world_size} instead." - ) - - if local_world_size > gang.size: - raise ValueError( - f"`local_world_size` must be less than or equal to `gang.size` ({gang.size}), but is {local_world_size} instead." - ) - - if gang.size % local_world_size != 0: - raise ValueError( - f"`gang.size` ({gang.size}) must be divisible by `local_world_size` ({local_world_size})." - ) - - # TODO(balioglu): Finish! - raise NotImplementedError("`local_world_size` is not supported yet.") + sharding_strategy = ShardingStrategy.HYBRID_SHARD + + sharding_gang, replication_gang = setup_hybrid_fsdp_gangs( + gang, local_world_size + ) + + process_group = ( + sharding_gang.as_process_group(), + replication_gang.as_process_group(), + ) else: if reshard_after_forward: sharding_strategy = ShardingStrategy.FULL_SHARD else: sharding_strategy = ShardingStrategy.SHARD_GRAD_OP + process_group = gang.as_process_group() + if memory_policy is None: memory_policy = FSDP_STANDARD_MEMORY_POLICY - param_init_fn = None + if skip_init: + log.warning("`skip_init` parameter has no effect and will be removed in a future release.") # fmt: skip - module_device = infer_device(module) - if module_device.type == "meta": - if not torch_greater_or_equal(2, 1): - log.warning("FSDP meta initialization is only supported on PyTorch 2.1.0 and later.") # fmt: skip + param_init_fn = None - to_empty(module, gang.device) + try: + module_device = infer_device(module) + except ValueError as ex: + raise ValueError( + "The device of `module` is not valid. See the nested exception for details." + ) from ex - if not broadcast_state: - reset_parameters(module) + if module_device.type == "meta": + if gang.rank == 0: + skip_init = not broadcast_state else: - param_init_fn = FSDPParameterInitializer(gang.device, skip_init) + skip_init = True + + param_init_fn = FSDPParameterInitializer(gang.device, skip_init) if mixed_precision_dtype is None: mp = None @@ -142,7 +145,7 @@ def to_fsdp( buffer_dtype=None, ) - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} # As of PyTorch 2.0, FSDP initialization fails in certain settings when an # empty `ignored_states` is specified (e.g. `sync_module_states` is set). @@ -151,7 +154,7 @@ def to_fsdp( fsdp = FSDP( module, - process_group=gang.as_process_group(), + process_group=process_group, sharding_strategy=sharding_strategy, cpu_offload=CPUOffload() if memory_policy.cpu_offload else None, auto_wrap_policy=wrap_policy, @@ -166,12 +169,15 @@ def to_fsdp( **kwargs, ) - FSDP.set_state_dict_type( - fsdp, - StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(offload_to_cpu=True), - optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True), - ) + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + FSDP.set_state_dict_type( + fsdp, + StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(offload_to_cpu=True), + optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True), + ) return fsdp @@ -200,7 +206,7 @@ def __call__(self, module: Module, recurse: bool, non_wrapped_numel: int) -> boo class FSDPMemoryPolicy: """Specifies the device memory usage policy of an FSDP module.""" - backward_prefetch: Optional[BackwardPrefetch] + backward_prefetch: BackwardPrefetch | None """The backward prefetch mode for all-gathers. For more information, check out the same named parameter of :class:`FSDP`.""" @@ -255,8 +261,8 @@ class FSDPParameterInitializer: ... ) """ - _module_memo: Set[Module] - _memo: Dict[Tensor, Tensor] + _module_memo: set[Module] + _memo: dict[Tensor, Tensor] _device: Device _skip_init: bool @@ -303,6 +309,8 @@ def summon_fsdp_for_validation(module: Module) -> Iterator[None]: if not isinstance(module, FSDP): yield else: + mp = module.mixed_precision or MixedPrecision() + # This is ugly, but our only option. We monkey-patch FSDP modules to # replace their `forward` methods with the wrapped `forward` methods. # Otherwise, FSDP fails to shard parameters at the end of the call. @@ -320,9 +328,19 @@ def enable_fsdp_forward(module_: Module) -> None: del m._fs2_backup_forward + def maybe_cast_dtype(t: Tensor) -> Tensor: + dtype = mp.param_dtype if isinstance(t, Parameter) else mp.buffer_dtype + + if dtype is None: + return t + + return t.to(dtype) + with FSDP.summon_full_params(module, writeback=False): disable_fsdp_forward(module) + apply_to_parameters(module, maybe_cast_dtype) + try: yield finally: diff --git a/src/fairseq2/nn/functional.py b/src/fairseq2/nn/functional.py index 8a6017bf5..18a02f3f0 100644 --- a/src/fairseq2/nn/functional.py +++ b/src/fairseq2/nn/functional.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Literal, Optional +from typing import Literal from torch import Tensor @@ -14,7 +14,7 @@ def nll_loss( lprobs: Tensor, targets: Tensor, - pad_idx: Optional[int], + pad_idx: int | None, *, label_smoothing: float = 0.0, reduction: Literal["none", "sum"] = "sum", diff --git a/src/fairseq2/nn/incremental_state.py b/src/fairseq2/nn/incremental_state.py index 1146ca77f..9b32af99e 100644 --- a/src/fairseq2/nn/incremental_state.py +++ b/src/fairseq2/nn/incremental_state.py @@ -7,7 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, Optional, Type, TypeVar, final +from typing import TypeVar, final from torch import Tensor from torch.nn import Module @@ -54,11 +54,11 @@ class IncrementalStateBag: _step_nr: int _max_num_steps: int - _capacity_increment: Optional[int] - _module_states: Dict[Module, IncrementalState] + _capacity_increment: int | None + _module_states: dict[Module, IncrementalState] def __init__( - self, max_num_steps: int, *, capacity_increment: Optional[int] = 16 + self, max_num_steps: int, *, capacity_increment: int | None = 16 ) -> None: """ :param max_num_steps: @@ -97,7 +97,7 @@ def increment_step_nr(self, value: int = 1) -> None: self._step_nr = step_nr - def get_state(self, m: Module, kls: Type[T]) -> Optional[T]: + def get_state(self, m: Module, kls: type[T]) -> T | None: """Get the state of ``m`` if present in the bag. :param m: @@ -144,7 +144,7 @@ def max_num_steps(self) -> int: return self._max_num_steps @property - def capacity_increment(self) -> Optional[int]: + def capacity_increment(self) -> int | None: """The sequence length capacity of state tensors will be incremented by multiples of this value.""" return self._capacity_increment diff --git a/src/fairseq2/nn/lora.py b/src/fairseq2/nn/lora.py index 5a802ec45..7993d3ec3 100644 --- a/src/fairseq2/nn/lora.py +++ b/src/fairseq2/nn/lora.py @@ -10,7 +10,7 @@ import re from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional, final +from typing import Literal, final import torch import torch.nn as nn @@ -29,7 +29,7 @@ class LoRAConfig: r: int alpha: float dropout_p: float - keys: List[str] + keys: list[str] class LoRALayer(ABC): @@ -65,8 +65,8 @@ def __init__( self, wrapped: Embedding, config: LoRAConfig, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param wrapped: @@ -142,10 +142,10 @@ def unmerge(self) -> None: class LoRALinear(Projection, LoRALayer): wrapped: Projection weight: Parameter - bias: Optional[Parameter] + bias: Parameter | None lora_A: Parameter lora_B: Parameter - dropout: Optional[Dropout] + dropout: Dropout | None skip_init: bool merged: bool @@ -154,8 +154,8 @@ def __init__( wrapped: Projection, config: LoRAConfig, skip_init: bool = False, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param wrapped: @@ -256,7 +256,7 @@ def wrap_lora( parent = module.get_submodule(".".join(submodule_path[:-1])) submodule_name = submodule_path[-1] - lora_layer: Optional[LoRALayer] = None + lora_layer: LoRALayer | None = None if isinstance(submodule, Projection): lora_layer = LoRALinear( wrapped=submodule, @@ -299,7 +299,7 @@ def unwrap_lora(module: nn.Module, merge: bool = True) -> nn.Module: parent = module.get_submodule(".".join(submodule_path[:-1])) submodule_name = submodule_path[-1] - unwrapped_layer: Optional[nn.Module] = None + unwrapped_layer: nn.Module | None = None if isinstance(submodule, LoRALayer): unwrapped_layer = submodule.wrapped_module else: @@ -326,7 +326,7 @@ def unmerge_lora(module: nn.Module) -> None: submodule.unmerge() -def lora_state_dict(module: nn.Module) -> Dict[str, Any]: +def lora_state_dict(module: nn.Module) -> dict[str, object]: lora_names = [] for name, submodule in module.named_modules(): if isinstance(submodule, LoRALayer): @@ -359,6 +359,6 @@ def freeze_non_lora( param.requires_grad = False -def _is_target_module(name: str, target_keys: List[str]) -> bool: +def _is_target_module(name: str, target_keys: list[str]) -> bool: # Check if the `name` matches any of the `target_keys``. return any(name == key or re.match(key, name) for key in target_keys) diff --git a/src/fairseq2/nn/normalization.py b/src/fairseq2/nn/normalization.py index 603027c1d..c371bc156 100644 --- a/src/fairseq2/nn/normalization.py +++ b/src/fairseq2/nn/normalization.py @@ -7,13 +7,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Literal, Optional, Sequence, Tuple, Union, final +from collections.abc import Callable, Sequence +from typing import Any, Literal, final import torch import torch.nn as nn -from torch import Tensor +from torch import Size, Tensor from torch.nn import Module, Parameter from torch.nn.functional import layer_norm +from typing_extensions import override try: from apex.normalization.fused_layer_norm import ( # type: ignore[import] @@ -25,27 +27,30 @@ except ImportError: _has_apex = False -from fairseq2.typing import DataType, Device, override +from fairseq2.error import NotSupportedError +from fairseq2.typing import DataType, Device class LayerNorm(Module, ABC): """Applies Layer Normalization to incoming data.""" - normalized_shape: Tuple[int, ...] + normalized_shape: tuple[int, ...] eps: float elementwise_affine: bool - weight: Optional[Parameter] - bias: Optional[Parameter] + weight: Parameter | None + bias: Parameter | None + init_fn: Callable[[LayerNorm], None] | None def __init__( self, - normalized_shape: Union[int, Sequence[int], torch.Size], + normalized_shape: int | Sequence[int] | Size, bias: bool, *, eps: float = 1e-5, elementwise_affine: bool = True, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[LayerNorm], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param normalized_shape: @@ -85,15 +90,20 @@ def __init__( else: self.register_parameter("bias", None) + self.init_fn = init_fn + self.reset_parameters() def reset_parameters(self) -> None: """Reset the parameters and buffers of the module.""" - if self.weight is not None: - nn.init.ones_(self.weight) + if self.init_fn is not None: + self.init_fn(self) + else: + if self.weight is not None: + nn.init.ones_(self.weight) - if self.bias is not None: - nn.init.zeros_(self.bias) + if self.bias is not None: + nn.init.zeros_(self.bias) @abstractmethod def forward(self, x: Tensor) -> Tensor: @@ -149,14 +159,18 @@ def __init__( if impl == "apex": if not _has_apex: - raise RuntimeError( + raise NotSupportedError( "`impl` is 'apex', but no APEX installation can be found." ) if self.bias is not None: - raise RuntimeError( + raise NotSupportedError( "`impl` is 'apex', but APEX does not support the `bias` parameter." ) + elif impl != "py": + raise ValueError( + f"`impl` must be 'auto', 'py', or 'apex', but is '{impl}' instead." + ) self._impl = impl diff --git a/src/fairseq2/nn/padding.py b/src/fairseq2/nn/padding.py index 120a1488f..643835d6f 100644 --- a/src/fairseq2/nn/padding.py +++ b/src/fairseq2/nn/padding.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, Tuple, cast, final +from collections.abc import Sequence +from typing import cast, final import torch from torch import Tensor @@ -21,8 +22,8 @@ class PaddingMask: _seq_lens: Tensor _batch_seq_len: int - _materialized: Optional[Tensor] - _materialized_float: Optional[Tensor] + _materialized: Tensor | None + _materialized_float: Tensor | None def __init__(self, seq_lens: Tensor, batch_seq_len: int) -> None: """ @@ -109,7 +110,7 @@ def to_padding_mask(seq_lens: Tensor, batch_seq_len: int) -> Tensor: return indices < lengths -def get_seq_lens(seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor: +def get_seq_lens(seqs: Tensor, padding_mask: PaddingMask | None) -> Tensor: """Retrieve the sequence lengths of ``seqs``. :param seqs: @@ -132,7 +133,7 @@ def get_seq_lens(seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor: def apply_padding_mask( - seqs: Tensor, padding_mask: Optional[PaddingMask], pad_value: Any = 0 + seqs: Tensor, padding_mask: PaddingMask | None, pad_value: int | float | Tensor = 0 ) -> Tensor: """Apply the specified padding mask to ``seqs``. @@ -161,8 +162,8 @@ def apply_padding_mask( def get_seqs_and_padding_mask( - data: SequenceData, device: Optional[Device] = None -) -> Tuple[Tensor, Optional[PaddingMask]]: + data: SequenceData, device: Device | None = None +) -> tuple[Tensor, PaddingMask | None]: """Return the sequences along with their padding mask from ``data``. :returns: @@ -187,7 +188,7 @@ def get_seqs_and_padding_mask( def pad_seqs( seqs: Sequence[Tensor], pad_value: int = 0, pad_to_multiple: int = 1 -) -> Tuple[Tensor, Optional[PaddingMask]]: +) -> tuple[Tensor, PaddingMask | None]: """Stack ``seqs`` along a new batch dimension and pad them to equal length. :param seqs: diff --git a/src/fairseq2/nn/position_encoder.py b/src/fairseq2/nn/position_encoder.py index 9aeb987c3..1340496e3 100644 --- a/src/fairseq2/nn/position_encoder.py +++ b/src/fairseq2/nn/position_encoder.py @@ -8,33 +8,38 @@ import math from abc import ABC, abstractmethod -from typing import Callable, Optional, final +from collections.abc import Callable +from typing import final import torch import torch.nn as nn from torch import Tensor from torch.nn import Module -from torch.nn.functional import embedding +from torch.nn.functional import embedding, interpolate from torch.nn.parameter import Parameter +from typing_extensions import override +from fairseq2.error import InternalError from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.padding import PaddingMask -from fairseq2.typing import DataType, Device, override -from fairseq2.utils.version import torch_greater_or_equal +from fairseq2.typing import DataType, Device class PositionEncoder(Module, ABC): """Encodes sequences with positional information.""" encoding_dim: int - max_seq_len: Optional[int] + max_seq_len: int | None - def __init__(self, encoding_dim: int, max_seq_len: Optional[int]) -> None: + def __init__(self, encoding_dim: int, max_seq_len: int | None) -> None: """ - :param encoding_dim: - The dimensionality of positional encodings. - :param max_seq_len: - The maximum sequence length. + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of input sequences is expected to have the same + dimensionality. + :param max_seq_len: The maximum allowed length for input sequences. + Sequences longer than ``max_seq_len`` will cause a :class:`ValueError`. + Typically it is set to the context length of the underlying model. + If ``None``, sequences can have arbitrary length. """ super().__init__() @@ -44,26 +49,30 @@ def __init__(self, encoding_dim: int, max_seq_len: Optional[int]) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, *, - state_bag: Optional[IncrementalStateBag] = None, + state_bag: IncrementalStateBag | None = None, ) -> Tensor: """ - :param seqs: - The sequences to encode with positional information. *Shape:* - :math:`(*,S,E)`, where :math:`*` is any number of batch dimensions - including none, :math:`S` is the sequence length, and :math:`E` is - the dimensionality of the positional encodings. - :param padding_mask: - The padding mask of ``seqs``. *Shape:* :math:`(*,S)`, where :math:`*` - is any number of batch dimensions including none and :math:`S` is - the sequence length. - :param state_bag: - The state bag to use for incremental decoding. - - :returns: - The input sequences with positional information encoded. *Shape:* - Same as ``seqs``. + Returns a copy of ``seqs`` with positional information encoded. + + :param seqs: The input sequences to encode. *Shape:* :math:`(*,S,E)`, + where :math:`*` is any number of batch dimensions including none, + :math:`S` is the sequence length, and :math:`E` is the dimensionality + of the positional encodings. + :param padding_mask: The padding mask of ``seqs``. *Shape:* :math:`(*,S)`, + where :math:`*` is any number of batch dimensions including none and + :math:`S` is the sequence length. + :param state_bag: If not ``None``, the encoder will operate in + incremental decoding mode. This means that the first step in ``seqs`` + will be considered to be at position :attr:`state_bag.step_nr + ` instead of 0. + + :raises ValueError: when the sequence length of ``seqs`` exceeds + :attr:`max_seq_len`. + + :returns: The input sequences with positional information encoded. + *Shape:* Same as ``seqs``. """ if self.max_seq_len is not None: if self.training or state_bag is None: @@ -82,25 +91,12 @@ def forward( def _do_forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: """ - :param seqs: - The sequences to encode with positional information. *Shape:* - :math:`(*,S,E)`, where :math:`*` is any number of batch dimensions - including none, :math:`S` is the sequence length, and :math:`E` is - the dimensionality of the positional encodings. - :param padding_mask: - The padding mask of ``seqs``. *Shape:* :math:`(*,S)`, where :math:`*` - is any number of batch dimensions including none and :math:`S` is - the sequence length. - :param state_bag: - The state bag to use for incremental decoding. - - :returns: - The input sequences with positional information encoded. *Shape:* - Same as ``seqs``. + When overriden in a subclass, returns a copy of ``seqs`` with positional + information encoded. See :meth:`forward` for parameter descriptions. :meta public: """ @@ -117,42 +113,7 @@ def extra_repr(self) -> str: @final class SinusoidalPositionEncoder(PositionEncoder): - """Encodes sequences with fixed sinusoidal positional information. - - The positional encodings are initialized as in tensor2tensor which differs - slightly from the description in section 3.5 of - :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. This means instead of - - .. math:: - PE_{(pos, 2i)} = \\text{sin}(pos/10000^{2i/d_{model}}) - - PE_{(pos, 2i+1)} = \\text{cos}(pos/10000^{2i/d_{model}}) - - we use - - .. math:: - PE_{(pos, i)} = \\text{sin}(pos/10000^{i/d_{model}})\\;\\text{for}\\;i\\; <\\frac{d_{model}}{2} - - PE_{(pos, i)} = \\text{cos}(pos/10000^{i/d_{model}})\\;\\text{for}\\;i\\;\\geq\\frac{d_{model}}{2} - - See `here `_ for more - information. - - Usage: - - >>> import torch - >>> - >>> from fairseq2.nn.position_encoder import SinusoidalPositionEncoder - >>> - >>> m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=16) - >>> - >>> seqs = torch.ones((3, 4)) - >>> - >>> m(seqs) - tensor([[ 1.0000e+00, 1.0000e+00, 2.0000e+00, 2.0000e+00], # pos 0 - [ 9.4147e-01, 2.0000e-04, 6.4030e-01, 2.0000e+00], # pos 1 - [ 1.0930e-02, 3.0000e-04, -5.1615e-01, 2.0000e+00]]) # pos 2 - """ + """Encodes sequences with fixed sinusoidal positional information.""" freqs: Tensor @@ -161,9 +122,18 @@ def __init__( encoding_dim: int, max_seq_len: int, *, - _legacy_pad_idx: Optional[int] = None, - device: Optional[Device] = None, + _legacy_pad_idx: int | None = None, + device: Device | None = None, ) -> None: + """ + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of input sequences is expected to have the same + dimensionality. + :param max_seq_len: The maximum allowed length for input sequences. + Sequences longer than ``max_seq_len`` will cause a :class:`ValueError`. + + :raise ValueError: when ``encoding_dim`` is not even. + """ super().__init__(encoding_dim, max_seq_len) if encoding_dim % 2 != 0: @@ -171,6 +141,12 @@ def __init__( f"`encoding_dim` must be even, but is {encoding_dim} instead." ) + freqs = torch.empty( + (max_seq_len, encoding_dim), device=device, dtype=torch.float32 + ) + + self.register_buffer("freqs", freqs, persistent=False) + # This is a legacy parameter that should only be set when the encodings # must be compatible with fairseq. if _legacy_pad_idx is None: @@ -178,12 +154,6 @@ def __init__( else: self._sin_offset = 1 + _legacy_pad_idx - freqs = torch.empty( - (max_seq_len, encoding_dim), device=device, dtype=torch.float32 - ) - - self.register_buffer("freqs", freqs, persistent=False) - self.reset_parameters() def reset_parameters(self) -> None: @@ -192,15 +162,11 @@ def reset_parameters(self) -> None: def reset_non_persistent_buffers(self) -> None: """Reset the non-persistent buffers of the module.""" - assert self.max_seq_len is not None + if self.max_seq_len is None: + raise InternalError("`max_seq_len` is `None`.") device, dtype = self.freqs.device, self.freqs.dtype - num_sin = self.encoding_dim // 2 - - l_half = self.freqs[:, :num_sin] - r_half = self.freqs[:, num_sin:] - start_step = self._sin_offset # (S) @@ -208,26 +174,14 @@ def reset_non_persistent_buffers(self) -> None: start_step, start_step + self.max_seq_len, device=device, dtype=dtype ) - # (E) - indices = torch.arange(num_sin, device=device, dtype=dtype) - - # This is identical to tensor2tensor's implementation. - freqs = torch.exp(indices * -math.log(10000.0) / (num_sin - 1)) - - # (S) x (E) -> (S, E) - torch.outer(steps, freqs, out=l_half) - - r_half.copy_(l_half) - - l_half.sin_() - r_half.cos_() + _fill_sin_freq_table(self.freqs, self.encoding_dim, steps) @override def _do_forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: """:meta private:""" seq_len = seqs.size(-2) @@ -242,25 +196,38 @@ def _do_forward( return fp32_seqs.type_as(seqs) +def _fill_sin_freq_table( + freqs: Tensor, encoding_dim: int, steps: Tensor, correction: int = 1 +) -> None: + freqs = freqs.flatten(0, -2) + + num_sin = encoding_dim // 2 + + l_half = freqs[:, :num_sin] + r_half = freqs[:, num_sin:] + + # (E) + indices = torch.arange(num_sin, device=steps.device, dtype=steps.dtype) + + # This is identical to tensor2tensor's implementation. + freqs = torch.exp(indices * -math.log(10000.0) / (num_sin - correction)) + + # (S) x (E) -> (S, E) + torch.outer(steps, freqs, out=l_half) + + # The cosine frequencies might be truncated if the table is shorter than the + # encoding dimension due to rounding. + r_dim = r_half.size(1) + + r_half.copy_(l_half[:, :r_dim]) + + l_half.sin_() + r_half.cos_() + + @final class LearnedPositionEncoder(PositionEncoder): - """Encodes sequences with learned positional embeddings. - - Usage: - - >>> import torch - >>> - >>> from fairseq2.nn.position_encoder import LearnedPositionEncoder - >>> - >>> m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=16) - >>> - >>> seqs = torch.ones((3, 4)) - >>> - >>> m(seqs) - tensor([[ 1.1135, 0.5548, 0.4293, 2.0112], # pos 0 - [ 0.2364, 0.6009, 3.3865, -2.4810], # pos 1 - [-0.4746, 0.4544, 0.2761, 0.8828]], grad_fn=) # pos 2 - """ + """Encodes sequences with learned positional embeddings.""" weight: Parameter @@ -269,9 +236,16 @@ def __init__( encoding_dim: int, max_seq_len: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: + """ + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of input sequences is expected to have the same + dimensionality. + :param max_seq_len: The maximum allowed length for input sequences. + Sequences longer than ``max_seq_len`` will cause a :class:`ValueError`. + """ super().__init__(encoding_dim, max_seq_len) self.weight = Parameter( @@ -288,8 +262,8 @@ def reset_parameters(self) -> None: def _do_forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: """:meta private:""" seq_len = seqs.size(-2) @@ -308,12 +282,14 @@ def _do_forward( @final class RotaryEncoder(PositionEncoder): - """Encodes sequences with relative positional information as described in - :cite:t:`https://doi.org/10.48550/arxiv.2104.09864`.""" + """ + Encodes sequences with relative positional information as described in + :cite:t:`https://doi.org/10.48550/arxiv.2104.09864`. + """ freqs: Tensor theta: float - freqs_init_fn: Optional[Callable[[RotaryEncoder], Tensor]] + freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None def __init__( self, @@ -321,15 +297,24 @@ def __init__( max_seq_len: int, *, theta: float = 10_000.0, - freqs_init_fn: Optional[Callable[[RotaryEncoder], Tensor]] = None, - device: Optional[Device] = None, + freqs_init_fn: Callable[[RotaryEncoder], Tensor] | None = None, + device: Device | None = None, ) -> None: """ - :param theta: - The coefficient of the long-term decay as described in section 3.3 - of the reference paper. - :param freqs_init_fn: - The callable to initialize the frequency table. + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of input sequences is expected to have the same + dimensionality. + :param max_seq_len: The maximum allowed length for input sequences. + Sequences longer than ``max_seq_len`` will cause a :class:`ValueError`. + :param theta: The coefficient of the long-term decay as described in + section 3.3 of the reference paper. + :param freqs_init_fn: A callable to initialize the frequency table. The + encoder will be passed to the callable as an argument and it is + expected for the callable to return a :class:`~torch.Tensor` holding + the frequency table. If ``None``, the frequencies will be initialized + as described in the reference paper. + + :raise ValueError: when ``encoding_dim`` is not even. """ super().__init__(encoding_dim, max_seq_len) @@ -355,15 +340,11 @@ def reset_parameters(self) -> None: def reset_non_persistent_buffers(self) -> None: """Reset the non-persistent buffers of the module.""" - assert self.max_seq_len is not None + if self.max_seq_len is None: + raise InternalError("`max_seq_len` is `None`.") device = self.freqs.device - # In PyTorch 2.0 and 2.1, `torch.polar` does not support meta device. - if not torch_greater_or_equal(2, 2): - if device.type == "meta": - return - complex_freqs = torch.view_as_complex(self.freqs) # (S) @@ -389,8 +370,8 @@ def reset_non_persistent_buffers(self) -> None: def _do_forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: """:meta private:""" seq_len = seqs.size(-2) @@ -415,3 +396,325 @@ def _do_forward( fp32_seqs = torch.view_as_real(complex_seqs).flatten(-2) return fp32_seqs.type_as(seqs) + + +class InterpolatedPositionEncoder(Module, ABC): + """Encodes N-dimensional inputs with interpolated positional information.""" + + encoding_dim: int + + def __init__(self, encoding_dim: int) -> None: + """ + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of inputs is expected to have the same dimensionality. + """ + super().__init__() + + self.encoding_dim = encoding_dim + + @abstractmethod + def forward(self, x: Tensor) -> Tensor: + """ + Returns a copy of ``x`` with positional information encoded. + + :params x: The inputs to encode. *Shape:* :math:`(N,*,E)`, where + :math:`N` is the batch size, :math:`*` is any number of + implementation-specific dimensions, and :math:`E` is the + dimensionality of the positional encodings. + + :returns: The inputs with positional information encoded. *Shape:* Same + as ``x``. + """ + + def extra_repr(self) -> str: + """:meta private:""" + return f"encoding_dim={self.encoding_dim}" + + +class SinusoidalNdPositionEncoder(InterpolatedPositionEncoder): + """ + Provides a skeletal implementation of interpolated sinusoidal position + encoders. + """ + + freqs: Tensor + grid_dims: tuple[int, ...] + + def __init__( + self, + encoding_dim: int, + grid_dims: tuple[int, ...], + *, + device: Device | None = None, + ) -> None: + """ + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of inputs is expected to have the same dimensionality. + :param grid_dims: The dimensionality of the frequency table. + """ + super().__init__(encoding_dim) + + if encoding_dim % 2 != 0: + raise ValueError( + f"`encoding_dim` must be even, but is {encoding_dim} instead." + ) + + freqs = torch.empty( + grid_dims + (encoding_dim,), device=device, dtype=torch.float32 + ) + + self.grid_dims = grid_dims + + self.register_buffer("freqs", freqs, persistent=False) + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + self.reset_non_persistent_buffers() + + @abstractmethod + def reset_non_persistent_buffers(self) -> None: + """Reset the non-persistent buffers of the module.""" + + @override + def forward(self, x: Tensor) -> Tensor: + freqs = self._interpolate_freqs_as(x) + + fp32_x = x.float() + freqs + + return fp32_x.type_as(x) + + @abstractmethod + def _interpolate_freqs_as(self, x: Tensor) -> Tensor: + """ + Interpolates (or extrapolates) the frequency table to the dimensionality + of ``x``. + + :params x: The inputs to encode. *Shape:* :math:`(N,*,E)`, where + :math:`N` is the batch size, :math:`*` is the same number of + dimensions as :attr:`grid_dims`, but potentially with different + dimensionality, and :math:`E` is the dimensionality of the + positional encodings. + + :returns: The interpolated (or extrapolated) frequency table. *Shape:* + Same as ``x``. + """ + + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + + return f"{s}, grid_dims={self.grid_dims}" + + +class Sinusoidal2dPositionEncoder(SinusoidalNdPositionEncoder): + """Encodes 2-dimensional inputs with sinusoidal positional information. + + .. note:: + This implementation uses bicubic interpolation. The interpolation + technique can be changed by subclassing this type and overriding the + :meth:`_interpolate_freqs_as` method. + """ + + def __init__( + self, + encoding_dim: int, + grid_dims: tuple[int, int], + *, + device: Device | None = None, + ) -> None: + """ + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of inputs is expected to have the same dimensionality. + :param grid_dims: The dimensionality of the depth, height, and width + dimensions. + """ + super().__init__(encoding_dim, grid_dims, device=device) + + self.reset_parameters() + + @override + def reset_non_persistent_buffers(self) -> None: + freqs = self.freqs + + device, dtype = freqs.device, freqs.dtype + + h, w = freqs.shape[:-1] + + h_steps = torch.arange(h, device=device, dtype=dtype) + w_steps = torch.arange(w, device=device, dtype=dtype) + + h_coords, w_coords = torch.meshgrid(h_steps, w_steps, indexing="ij") + + h_coords = h_coords.flatten() + w_coords = w_coords.flatten() + + uniform_dim = math.ceil(self.encoding_dim / 4) * 2 + + h_dim = uniform_dim + w_dim = uniform_dim + + idx = 0 + + _fill_sin_freq_table( + freqs[..., idx : idx + h_dim], h_dim, h_coords, correction=0 + ) + + idx = h_dim + + _fill_sin_freq_table( + freqs[..., idx : idx + w_dim], w_dim, w_coords, correction=0 + ) + + @override + def _interpolate_freqs_as(self, x: Tensor) -> Tensor: + freqs = self.freqs + + if x.ndim != 4: + raise ValueError( + f"`x` must be 4 dimensional, but is {x.ndim} dimensional instead." + ) + + frq_dims, inp_dims = freqs.shape[:-1], x.shape[1:-1] + + if frq_dims == inp_dims: + return freqs + + frq_h, frq_w = frq_dims + inp_h, inp_w = inp_dims + + scale_factor = math.sqrt((inp_h * inp_w) / (frq_h * frq_w)) + + # (H_frq, W_frq, E) -> (1, H_frq, W_frq, E) + freqs = freqs.unsqueeze(0) + + # (1, H_frq, W_frq, E) -> (1, E, H_frq, W_frq) + freqs = freqs.permute(0, 3, 1, 2) + + # (1, E, H_frq, W_frq) -> (1, E, H_inp, W_inp) + freqs = interpolate(freqs, scale_factor=scale_factor, mode="bicubic") + + # (1, E, H_inp, W_inp) -> (1, H_inp, W_inp, E) + freqs = freqs.permute(0, 2, 3, 1) + + # (1, H_inp, W_inp, E) -> (H_inp, W_inp, E) + return freqs.squeeze(0) + + +class Sinusoidal3dPositionEncoder(SinusoidalNdPositionEncoder): + """Encodes 3-dimensional inputs with sinusoidal positional information. + + .. note:: + This implementation uses trilinear interpolation. The interpolation + technique can be changed by subclassing this type and overriding the + :meth:`_interpolate_freqs_as` method. + """ + + uniform_power: bool + + def __init__( + self, + encoding_dim: int, + grid_dims: tuple[int, int, int], + *, + uniform_power: bool = False, + device: Device | None = None, + ) -> None: + """ + :param encoding_dim: The dimensionality of positional encodings. The + last dimension of inputs is expected to have the same dimensionality. + :param grid_dims: The dimensionality of the depth, height, and width + dimensions. + :param uniform_power: If ``True``, each dimension of ``grid_dims`` will + have equal representation in the produced positional encodings. This + means, if ``True``, a positional encoding will consists of 1/3 depth, + 1/3 height, and 1/3 width information; otherwise, 1/2 depth, 1/4 + height, and 1/4 width information. + """ + super().__init__(encoding_dim, grid_dims, device=device) + + self.uniform_power = uniform_power + + self.reset_parameters() + + @override + def reset_non_persistent_buffers(self) -> None: + freqs = self.freqs + + device, dtype = freqs.device, freqs.dtype + + d, h, w = freqs.shape[:-1] + + d_steps = torch.arange(d, device=device, dtype=dtype) + h_steps = torch.arange(h, device=device, dtype=dtype) + w_steps = torch.arange(w, device=device, dtype=dtype) + + d_coords, h_coords, w_coords = torch.meshgrid( + d_steps, h_steps, w_steps, indexing="ij" + ) + + d_coords = d_coords.flatten() + h_coords = h_coords.flatten() + w_coords = w_coords.flatten() + + if self.uniform_power: + uniform_dim = math.ceil(self.encoding_dim / 6) * 2 + + d_dim = uniform_dim + h_dim = uniform_dim + w_dim = uniform_dim + else: + d_dim = math.ceil(self.encoding_dim / 4) * 2 + h_dim = math.ceil(self.encoding_dim / 8) * 2 + w_dim = math.ceil(self.encoding_dim / 8) * 2 + + idx = 0 + + _fill_sin_freq_table( + freqs[..., idx : idx + d_dim], d_dim, d_coords, correction=0 + ) + + idx = d_dim + + _fill_sin_freq_table( + freqs[..., idx : idx + h_dim], h_dim, h_coords, correction=0 + ) + + idx = d_dim + h_dim + + _fill_sin_freq_table( + freqs[..., idx : idx + w_dim], w_dim, w_coords, correction=0 + ) + + @override + def _interpolate_freqs_as(self, x: Tensor) -> Tensor: + freqs = self.freqs + + if x.ndim != 5: + raise ValueError( + f"`x` must be 5 dimensional, but is {x.ndim} dimensional instead." + ) + + frq_dims, inp_dims = freqs.shape[:-1], x.shape[1:-1] + + if frq_dims == inp_dims: + return freqs + + frq_d, frq_h, frq_w = frq_dims + inp_d, inp_h, inp_w = inp_dims + + scale_factor = (inp_d / frq_d, inp_h / frq_h, inp_w / frq_w) + + # (D_frq, H_frq, W_frq, E) -> (1, D_frq, H_frq, W_frq, E) + freqs = freqs.unsqueeze(0) + + # (1, D_frq, H_frq, W_frq, E) -> (1, E, D_frq, H_frq, W_frq) + freqs = freqs.permute(0, 4, 1, 2, 3) + + # (1, E, D_frq, H_frq, W_frq) -> (1, E, D_inp, H_inp, W_inp) + freqs = interpolate(freqs, scale_factor=scale_factor, mode="trilinear") + + # (1, E, D_inp, H_inp, W_inp) -> (1, D_inp, H_inp, W_inp, E) + freqs = freqs.permute(0, 2, 3, 4, 1) + + # (1, D_inp, H_inp, W_inp, E) -> (D_inp, H_inp, W_inp, E) + return freqs.squeeze(0) diff --git a/src/fairseq2/nn/projection.py b/src/fairseq2/nn/projection.py index 2801ed7a0..4c25b95df 100644 --- a/src/fairseq2/nn/projection.py +++ b/src/fairseq2/nn/projection.py @@ -8,7 +8,8 @@ import math from abc import ABC, abstractmethod -from typing import Callable, Optional, final +from collections.abc import Callable +from typing import final import torch import torch.nn as nn @@ -16,11 +17,13 @@ from torch.nn import Module from torch.nn.functional import linear from torch.nn.parameter import Parameter +from typing_extensions import override +from fairseq2.error import InternalError from fairseq2.gang import Gang from fairseq2.nn.utils.module import to_empty from fairseq2.tensor_parallel import gather, reduce, reduce_on_backward, scatter -from fairseq2.typing import META, DataType, Device, override +from fairseq2.typing import META, DataType, Device class Projection(Module, ABC): @@ -71,8 +74,8 @@ class Linear(Projection): """ weight: Parameter - bias: Optional[Parameter] - init_fn: Optional[Callable[[Linear], None]] + bias: Parameter | None + init_fn: Callable[[Linear], None] | None def __init__( self, @@ -80,9 +83,9 @@ def __init__( output_dim: int, bias: bool, *, - init_fn: Optional[Callable[[Linear], None]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[Linear], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param input_dim: @@ -116,7 +119,7 @@ def reset_parameters(self) -> None: if self.init_fn is not None: self.init_fn(self) else: - _init_uniform_weight_and_bias(self.weight, self.bias) + _init_uniform(self.weight, self.bias) @override def forward(self, x: Tensor) -> Tensor: @@ -143,9 +146,9 @@ class ColumnShardedLinear(Projection): gang: Gang sharded_output_dim: int weight: Parameter - bias: Optional[Parameter] + bias: Parameter | None gather_output: bool - init_fn: Optional[Callable[[Linear], None]] + init_fn: Callable[[Linear], None] | None @staticmethod def from_linear( @@ -193,9 +196,9 @@ def __init__( bias: bool, *, gather_output: bool = True, - init_fn: Optional[Callable[[Linear], None]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[Linear], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param gang: @@ -215,7 +218,7 @@ def __init__( if output_dim % gang.size != 0: raise ValueError( - f"`output_dim` must be divisible by `gang.size` ({gang.size}), but is {output_dim} instead." + f"`output_dim` must be a multiple of `gang.size` ({gang.size}), but is {output_dim} instead." ) self.gang = gang @@ -261,7 +264,8 @@ def _copy_weight(self, linear: Linear) -> None: self.weight.copy_(weight_shards[self.gang.rank]) if self.bias is not None: - assert linear.bias is not None + if linear.bias is None: + raise InternalError("`linear.bias` is `None`.") with torch.no_grad(): bias_shards = linear.bias.split(self.sharded_output_dim) @@ -279,7 +283,7 @@ def forward(self, x: Tensor) -> Tensor: return x - def to_linear(self, device: Optional[Device] = None) -> Linear: + def to_linear(self, device: Device | None = None) -> Linear: """Convert this instance to a :class:`Linear`.""" linear = self._linear_like(META) @@ -291,7 +295,8 @@ def to_linear(self, device: Optional[Device] = None) -> Linear: linear.weight.copy_(weight) if self.bias is not None: - assert linear.bias is not None + if linear.bias is None: + raise InternalError("`linear.bias` is `None`.") with torch.no_grad(): bias = gather(self.bias, self.gang, dim=0) @@ -342,9 +347,9 @@ class RowShardedLinear(Projection): gang: Gang sharded_input_dim: int weight: Parameter - bias: Optional[Parameter] + bias: Parameter | None scatter_input: bool - init_fn: Optional[Callable[[Linear], None]] + init_fn: Callable[[Linear], None] | None @staticmethod def from_linear( @@ -393,9 +398,9 @@ def __init__( bias: bool, *, scatter_input: bool = True, - init_fn: Optional[Callable[[Linear], None]] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + init_fn: Callable[[Linear], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param gang: @@ -416,7 +421,7 @@ def __init__( if input_dim % gang.size != 0: raise ValueError( - f"`input_dim` must be divisible by `gang.size` ({gang.size}), but is {input_dim} instead." + f"`input_dim` must be a multiple of `gang.size` ({gang.size}), but is {input_dim} instead." ) self.gang = gang @@ -462,7 +467,8 @@ def _copy_weight(self, linear: Linear) -> None: self.weight.copy_(weight_shards[self.gang.rank]) if self.bias is not None: - assert linear.bias is not None + if linear.bias is None: + raise InternalError("`linear.bias` is `None`.") with torch.no_grad(): self.bias.copy_(linear.bias) @@ -481,7 +487,7 @@ def forward(self, x: Tensor) -> Tensor: return x - def to_linear(self, device: Optional[Device] = None) -> Linear: + def to_linear(self, device: Device | None = None) -> Linear: """Convert this instance to a :class:`Linear`.""" linear = self._linear_like(META) @@ -493,7 +499,8 @@ def to_linear(self, device: Optional[Device] = None) -> Linear: linear.weight.copy_(weight) if self.bias is not None: - assert linear.bias is not None + if linear.bias is None: + raise InternalError("`linear.bias` is `None`.") with torch.no_grad(): linear.bias.copy_(self.bias) @@ -540,9 +547,9 @@ class TiedProjection(Projection): bias of another :class:`~torch.nn.Module` instance.""" weight: Parameter - bias: Optional[Parameter] + bias: Parameter | None - def __init__(self, weight: Parameter, bias: Optional[Parameter]) -> None: + def __init__(self, weight: Parameter, bias: Parameter | None) -> None: """ :param weight: The shared weights. @@ -559,7 +566,26 @@ def forward(self, x: Tensor) -> Tensor: return linear(x, self.weight, self.bias) -def _init_uniform_weight_and_bias(weight: Tensor, bias: Optional[Tensor]) -> None: +@final +class IdentityProjection(Projection): + """ + Used to disable a projection layer without changing the module architecture. + """ + + def __init__(self, input_dim: int, output_dim: int) -> None: + if input_dim != output_dim: + raise ValueError( + f"For identity projection, `input_dim` and `output_dim` must match, but are {input_dim} and {output_dim} instead." + ) + + super().__init__(input_dim, output_dim) + + @override + def forward(self, x: Tensor) -> Tensor: + return x + + +def _init_uniform(weight: Tensor, bias: Tensor | None) -> None: nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) if bias is not None: @@ -578,3 +604,11 @@ def _init_uniform_weight_and_bias(weight: Tensor, bias: Optional[Tensor]) -> Non bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(bias, -bound, bound) + + +def init_bert_projection(proj: Linear) -> None: + """Initialize ``proj`` as a projection to be used in BERT-like models.""" + nn.init.normal_(proj.weight, mean=0.0, std=0.02) + + if proj.bias is not None: + nn.init.zeros_(proj.bias) diff --git a/src/fairseq2/nn/transformer/__init__.py b/src/fairseq2/nn/transformer/__init__.py index 75f93d63d..28f1488a8 100644 --- a/src/fairseq2/nn/transformer/__init__.py +++ b/src/fairseq2/nn/transformer/__init__.py @@ -10,13 +10,13 @@ from fairseq2.nn.transformer.attention import NaiveSDPA as NaiveSDPA from fairseq2.nn.transformer.attention import SDPAFactory as SDPAFactory from fairseq2.nn.transformer.attention import TorchSDPA as TorchSDPA -from fairseq2.nn.transformer.attention import create_default_sdpa as create_default_sdpa from fairseq2.nn.transformer.attention import ( default_sdpa_factory as default_sdpa_factory, ) from fairseq2.nn.transformer.attention import ( enable_memory_efficient_torch_sdpa as enable_memory_efficient_torch_sdpa, ) +from fairseq2.nn.transformer.attention import make_default_sdpa as make_default_sdpa from fairseq2.nn.transformer.attention import ( set_default_sdpa_factory as set_default_sdpa_factory, ) @@ -64,6 +64,9 @@ from fairseq2.nn.transformer.encoder_layer import ( TransformerEncoderLayer as TransformerEncoderLayer, ) +from fairseq2.nn.transformer.ffn import ( + DauphinFeedForwardNetwork as DauphinFeedForwardNetwork, +) from fairseq2.nn.transformer.ffn import FeedForwardNetwork as FeedForwardNetwork from fairseq2.nn.transformer.ffn import GLUFeedForwardNetwork as GLUFeedForwardNetwork from fairseq2.nn.transformer.ffn import ( @@ -71,7 +74,7 @@ ) from fairseq2.nn.transformer.layer_norm import LayerNormFactory as LayerNormFactory from fairseq2.nn.transformer.layer_norm import ( - create_standard_layer_norm as create_standard_layer_norm, + make_standard_layer_norm as make_standard_layer_norm, ) from fairseq2.nn.transformer.multihead_attention import AttentionState as AttentionState from fairseq2.nn.transformer.multihead_attention import ( @@ -110,9 +113,27 @@ from fairseq2.nn.transformer.relative_attention import ( RelativePositionSDPA as RelativePositionSDPA, ) +from fairseq2.nn.transformer.residual import ( + DropPathResidualConnect as DropPathResidualConnect, +) +from fairseq2.nn.transformer.residual import ( + NormFormerResidualConnect as NormFormerResidualConnect, +) +from fairseq2.nn.transformer.residual import ResidualConnect as ResidualConnect +from fairseq2.nn.transformer.residual import ( + ScaledResidualConnect as ScaledResidualConnect, +) from fairseq2.nn.transformer.shaw_attention import ( ShawRelativePositionSDPA as ShawRelativePositionSDPA, ) from fairseq2.nn.transformer.shaw_attention import ( init_shaw_embedding as init_shaw_embedding, ) + +# isort: split + +# compat +from fairseq2.nn.transformer.attention import create_default_sdpa as create_default_sdpa +from fairseq2.nn.transformer.layer_norm import ( + create_standard_layer_norm as create_standard_layer_norm, +) diff --git a/src/fairseq2/nn/transformer/attention.py b/src/fairseq2/nn/transformer/attention.py index a089401e9..304fcb136 100644 --- a/src/fairseq2/nn/transformer/attention.py +++ b/src/fairseq2/nn/transformer/attention.py @@ -7,20 +7,19 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import contextmanager -from typing import Iterator, Optional, Protocol, Tuple, final +from typing import Protocol, final import torch from torch import Tensor from torch.nn import Module from torch.nn.functional import dropout, scaled_dot_product_attention, softmax +from typing_extensions import override -from fairseq2.logging import get_log_writer +from fairseq2.logging import log from fairseq2.nn.padding import PaddingMask from fairseq2.nn.transformer.attention_mask import AttentionMask, CausalAttentionMask -from fairseq2.typing import override - -log = get_log_writer(__name__) class SDPA(Module, ABC): @@ -31,12 +30,12 @@ def forward( self, seqs: Tensor, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, + attn_mask: AttentionMask | None = None, needs_weights: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: """ :param seqs: The sequences to query. *Shape:* :math:`(N,H,S,K)`, where :math:`N` @@ -104,12 +103,12 @@ def forward( self, seqs: Tensor, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, + attn_mask: AttentionMask | None = None, needs_weights: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: if needs_weights: if not self._has_warned: log.warning("`TorchSDPA` has to fall back to the naive SDPA implementation because of `needs_weights` set to `True`.") # fmt: skip @@ -226,12 +225,12 @@ def forward( self, seqs: Tensor, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, + attn_mask: AttentionMask | None = None, needs_weights: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: return _naive_scaled_dot_product_attention( seqs, keys, @@ -251,13 +250,13 @@ def extra_repr(self) -> str: def _naive_scaled_dot_product_attention( seqs: Tensor, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, - attn_mask: Optional[AttentionMask], + attn_mask: AttentionMask | None, dropout_p: float, needs_weights: bool, training: bool, -) -> Tuple[Tensor, Optional[Tensor]]: +) -> tuple[Tensor, Tensor | None]: # (N, H, S, K) @ (N, H, K, S_kv) = (N, H, S, S_kv) attn_weights = torch.matmul(seqs, keys.transpose(-1, -2)) @@ -310,7 +309,7 @@ def _get_fallback_sdpa_factory() -> SDPAFactory: _sdpa_factory: SDPAFactory = _get_fallback_sdpa_factory() -def set_default_sdpa_factory(factory: Optional[SDPAFactory]) -> None: +def set_default_sdpa_factory(factory: SDPAFactory | None) -> None: """Set the default :class:`SDPA` factory.""" global _sdpa_factory @@ -320,8 +319,8 @@ def set_default_sdpa_factory(factory: Optional[SDPAFactory]) -> None: _sdpa_factory = _get_fallback_sdpa_factory() -def create_default_sdpa(*, attn_dropout_p: float = 0.0) -> SDPA: - """Create an instance of the default :class:`SDPA`. +def make_default_sdpa(*, attn_dropout_p: float = 0.0) -> SDPA: + """Make an instance of the default :class:`SDPA`. :param attn_dropout_p: The dropout probability on attention weights. @@ -329,8 +328,11 @@ def create_default_sdpa(*, attn_dropout_p: float = 0.0) -> SDPA: return _sdpa_factory(attn_dropout_p=attn_dropout_p) +create_default_sdpa = make_default_sdpa # compat + + @contextmanager -def default_sdpa_factory(factory: Optional[SDPAFactory]) -> Iterator[None]: +def default_sdpa_factory(factory: SDPAFactory | None) -> Iterator[None]: """Set a temporary default :class:`SDPA` factory.""" original_factory = _sdpa_factory diff --git a/src/fairseq2/nn/transformer/attention_mask.py b/src/fairseq2/nn/transformer/attention_mask.py index 7d593f22b..0f82cd4a6 100644 --- a/src/fairseq2/nn/transformer/attention_mask.py +++ b/src/fairseq2/nn/transformer/attention_mask.py @@ -7,13 +7,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Protocol, final +from typing import Protocol, final import torch from torch import Tensor +from typing_extensions import override from fairseq2.nn.incremental_state import IncrementalStateBag -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device class AttentionMask(ABC): @@ -27,7 +28,7 @@ def materialize(self) -> Tensor: class AbstractAttentionMask(AttentionMask): """Provides a skeletal implementation of :class:`AttentionMask`.""" - _materialized: Optional[Tensor] + _materialized: Tensor | None def __init__(self) -> None: self._materialized = None @@ -54,11 +55,11 @@ def __call__( keys: Tensor, *, training: bool = True, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Optional[AttentionMask]: + state_bag: IncrementalStateBag | None = None, + ) -> AttentionMask | None: """ :param seqs: - The sequences for which to create a mask. *Shape:* :math:`(N,S,M)`, + The sequences for which to make a mask. *Shape:* :math:`(N,S,M)`, where :math:`N` is the batch size, :math:`S` is the sequence length, and :math:`M` is the dimensionality of the model. :param keys: @@ -123,20 +124,20 @@ class CausalAttentionMask(AbstractAttentionMask): _seq_len: int _key_len: int - _attn_len: Optional[int] - _attn_window_len: Optional[int] - _device: Optional[Device] - _dtype: Optional[DataType] + _attn_len: int | None + _attn_window_len: int | None + _device: Device | None + _dtype: DataType | None def __init__( self, seq_len: int, key_len: int, *, - attn_len: Optional[int] = None, - attn_window_len: Optional[int] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + attn_len: int | None = None, + attn_window_len: int | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param seq_len: @@ -162,7 +163,7 @@ def __init__( @override def _do_materialize(self) -> Tensor: - return _create_causal_attention_mask( + return _make_causal_attention_mask( self._seq_len, self._key_len, self._attn_len, @@ -180,9 +181,9 @@ def full_attention(self) -> bool: class CausalAttentionMaskFactory(AttentionMaskFactory): """Constructs instances of :class:`CausalAttentionMask`.""" - _attn_window_len: Optional[int] + _attn_window_len: int | None - def __init__(self, *, attn_window_len: Optional[int] = None) -> None: + def __init__(self, *, attn_window_len: int | None = None) -> None: """ :param attn_window_len: The attention window length as described in Section 3.1 of @@ -197,9 +198,9 @@ def __call__( keys: Tensor, *, training: bool = True, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Optional[CausalAttentionMask]: - attn_len: Optional[int] + state_bag: IncrementalStateBag | None = None, + ) -> CausalAttentionMask | None: + attn_len: int | None attn_len = seqs.size(1) @@ -257,9 +258,9 @@ class ALiBiMask(AbstractAttentionMask): _seq_len: int _key_len: int _num_attn_heads: int - _attn_len: Optional[int] = None - _device: Optional[Device] = None - _dtype: Optional[DataType] = None + _attn_len: int | None = None + _device: Device | None = None + _dtype: DataType | None = None def __init__( self, @@ -267,9 +268,9 @@ def __init__( key_len: int, num_attn_heads: int, *, - attn_len: Optional[int] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + attn_len: int | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param seq_len: @@ -324,7 +325,7 @@ def _do_materialize(self) -> Tensor: mask[:, :, -causal:] = -torch.inf else: # (S, S_kv) - causal_mask = _create_causal_attention_mask( + causal_mask = _make_causal_attention_mask( self._seq_len, self._key_len, attn_len, None, self._device, self._dtype ) @@ -353,9 +354,9 @@ def __call__( keys: Tensor, *, training: bool = True, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Optional[ALiBiMask]: - attn_len: Optional[int] + state_bag: IncrementalStateBag | None = None, + ) -> ALiBiMask | None: + attn_len: int | None attn_len = seqs.size(1) @@ -390,13 +391,13 @@ def __repr__(self) -> str: return f"ALiBiMaskFactory(num_attn_heads={self._num_attn_heads})" -def _create_causal_attention_mask( +def _make_causal_attention_mask( seq_len: int, key_len: int, - attn_len: Optional[int], - attn_window_len: Optional[int], - device: Optional[Device], - dtype: Optional[DataType], + attn_len: int | None, + attn_window_len: int | None, + device: Device | None, + dtype: DataType | None, ) -> Tensor: if dtype is None: dtype = torch.get_default_dtype() diff --git a/src/fairseq2/nn/transformer/decoder.py b/src/fairseq2/nn/transformer/decoder.py index 6a26cea0d..25b16afac 100644 --- a/src/fairseq2/nn/transformer/decoder.py +++ b/src/fairseq2/nn/transformer/decoder.py @@ -8,13 +8,16 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict, Iterable, Iterator, Optional, Protocol, Tuple, final +from collections.abc import Iterable, Iterator +from typing import Protocol, final import torch from torch import Generator, Tensor from torch.nn import Dropout, Module, ModuleList from torch.utils.hooks import RemovableHandle +from typing_extensions import override +from fairseq2.error import InvalidOperationError from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.padding import PaddingMask @@ -26,10 +29,10 @@ from fairseq2.nn.transformer.encoder import _record_drop_for_backward from fairseq2.nn.transformer.layer_norm import ( LayerNormFactory, - create_standard_layer_norm, + make_standard_layer_norm, ) from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import CPU, DataType, Device, override +from fairseq2.typing import CPU, DataType, Device class TransformerDecoder(Module, ABC): @@ -38,7 +41,7 @@ class TransformerDecoder(Module, ABC): model_dim: int layers: ModuleList - _layer_output_hooks: Dict[int, DecoderLayerOutputHook] + _layer_output_hooks: dict[int, DecoderLayerOutputHook] def __init__(self, model_dim: int) -> None: """ @@ -55,12 +58,12 @@ def __init__(self, model_dim: int) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - encoder_output: Optional[Tensor] = None, - encoder_padding_mask: Optional[PaddingMask] = None, + padding_mask: PaddingMask | None, + encoder_output: Tensor | None = None, + encoder_padding_mask: PaddingMask | None = None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: """ :param seqs: The sequences to decode. *Shape:* :math:`(N,S,M)`, where :math:`N` @@ -120,7 +123,7 @@ def __call__( self, layer_idx: int, layer_output: Tensor, - layer_padding_mask: Optional[PaddingMask], + layer_padding_mask: PaddingMask | None, num_layers: int, ) -> bool: """ @@ -145,10 +148,10 @@ class StandardTransformerDecoder(TransformerDecoder): """Represents a Transformer decoder as described in :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.""" - self_attn_mask_factory: Optional[AttentionMaskFactory] + self_attn_mask_factory: AttentionMaskFactory | None layer_drop_p: float - generator: Optional[Generator] - layer_norm: Optional[LayerNorm] + generator: Generator | None + layer_norm: LayerNorm | None dropout_p: float norm_order: TransformerNormOrder @@ -156,15 +159,15 @@ def __init__( self, layers: Iterable[TransformerDecoderLayer], *, - self_attn_mask_factory: Optional[AttentionMaskFactory] = None, + self_attn_mask_factory: AttentionMaskFactory | None = None, use_causal_attn_mask: bool = True, layer_drop_p: float = 0.0, - generator: Optional[Generator] = None, + generator: Generator | None = None, dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param layers: @@ -196,7 +199,7 @@ def __init__( super().__init__(model_dim) if layer_norm_factory is None: - layer_norm_factory = create_standard_layer_norm + layer_norm_factory = make_standard_layer_norm if self_attn_mask_factory is not None: self.self_attn_mask_factory = self_attn_mask_factory @@ -227,14 +230,14 @@ def __init__( def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - encoder_output: Optional[Tensor] = None, - encoder_padding_mask: Optional[PaddingMask] = None, + padding_mask: PaddingMask | None, + encoder_output: Tensor | None = None, + encoder_padding_mask: PaddingMask | None = None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: if self._layer_output_hooks and self.layer_drop_p > 0.0 and self.training: - raise RuntimeError( + raise InvalidOperationError( "The layer output hooks cannot be run when LayerDrop is enabled." ) @@ -276,7 +279,7 @@ def forward( return seqs, padding_mask - def _drop_iter(self) -> Iterator[Tuple[Module, bool]]: + def _drop_iter(self) -> Iterator[tuple[Module, bool]]: if self.training and self.layer_drop_p > 0.0: prob_dist = torch.rand( len(self.layers), generator=self.generator, device=CPU diff --git a/src/fairseq2/nn/transformer/decoder_layer.py b/src/fairseq2/nn/transformer/decoder_layer.py index 45cf9d808..76b9498c0 100644 --- a/src/fairseq2/nn/transformer/decoder_layer.py +++ b/src/fairseq2/nn/transformer/decoder_layer.py @@ -7,13 +7,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Tuple, cast, final +from typing import final -import torch -import torch.nn as nn from torch import Tensor from torch.nn import Dropout, Module -from torch.nn.parameter import Parameter +from typing_extensions import override from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.normalization import LayerNorm @@ -22,11 +20,12 @@ from fairseq2.nn.transformer.ffn import FeedForwardNetwork from fairseq2.nn.transformer.layer_norm import ( LayerNormFactory, - create_standard_layer_norm, + make_standard_layer_norm, ) from fairseq2.nn.transformer.multihead_attention import MultiheadAttention from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import DataType, Device, override +from fairseq2.nn.transformer.residual import ResidualConnect, StandardResidualConnect +from fairseq2.typing import DataType, Device class TransformerDecoderLayer(Module, ABC): @@ -47,13 +46,13 @@ def __init__(self, model_dim: int) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask] = None, - encoder_output: Optional[Tensor] = None, - encoder_padding_mask: Optional[PaddingMask] = None, + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + encoder_output: Tensor | None = None, + encoder_padding_mask: PaddingMask | None = None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: """ :param seqs: The sequences to process. *Shape:* :math:`(N,S,M)`, where :math:`N` @@ -95,30 +94,34 @@ class StandardTransformerDecoderLayer(TransformerDecoderLayer): :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.""" self_attn: MultiheadAttention - self_attn_norm: Optional[LayerNorm] - self_attn_dropout: Optional[Dropout] + self_attn_norm: LayerNorm | None + self_attn_dropout: Dropout | None + self_attn_residual: ResidualConnect self_attn_layer_norm: LayerNorm - encoder_decoder_attn: Optional[MultiheadAttention] - encoder_decoder_attn_dropout: Optional[Dropout] - encoder_decoder_attn_layer_norm: Optional[LayerNorm] + encoder_decoder_attn: MultiheadAttention | None + encoder_decoder_attn_dropout: Dropout | None + encoder_decoder_attn_residual: ResidualConnect | None + encoder_decoder_attn_layer_norm: LayerNorm | None ffn: FeedForwardNetwork - ffn_dropout: Optional[Dropout] - residual_scale: Optional[Parameter] + ffn_dropout: Dropout | None + ffn_residual: ResidualConnect ffn_layer_norm: LayerNorm norm_order: TransformerNormOrder def __init__( self, self_attn: MultiheadAttention, - encoder_decoder_attn: Optional[MultiheadAttention], + encoder_decoder_attn: MultiheadAttention | None, ffn: FeedForwardNetwork, *, - scale_residual: bool = False, dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + self_attn_residual: ResidualConnect | None = None, + encoder_decoder_attn_residual: ResidualConnect | None = None, + ffn_residual: ResidualConnect | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param self_attn: @@ -127,10 +130,6 @@ def __init__( The encoder-decoder attention layer. :param ffn: The feed-forward network. - :param scale_residual: - If ``True``, scales residuals before adding them to the output of - the feed-forward network as described in - :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`. :param dropout_p: The dropout probability on outputs of the attention layers and the feed-forward network. @@ -138,13 +137,22 @@ def __init__( The Layer Normalization order. :param layer_norm_factory: The factory to construct the Layer Normalization modules. + :param self_attn_residual: + The residual connection between the input and output of the self + attention layer. + :param encoder_decoder_attn_residual: + The residual connection between the input and output of the + encoder-decoder attention layer. + :param ffn_residual: + The residual connection between the input and output of the + feed-forward network. """ model_dim = self_attn.model_dim super().__init__(model_dim) if layer_norm_factory is None: - layer_norm_factory = create_standard_layer_norm + layer_norm_factory = make_standard_layer_norm self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) @@ -165,11 +173,18 @@ def __init__( else: self.register_module("self_attn_dropout", None) + if self_attn_residual is None: + self_attn_residual = StandardResidualConnect() + + self.self_attn_residual = self_attn_residual + if norm_order == TransformerNormOrder.POST: self.self_attn_layer_norm = self_attn_layer_norm if encoder_decoder_attn is None: self.register_module("encoder_decoder_attn", None) + self.register_module("encoder_decoder_attn_dropout", None) + self.register_module("encoder_decoder_attn_residual", None) self.register_module("encoder_decoder_attn_layer_norm", None) else: encoder_decoder_attn_layer_norm = layer_norm_factory( @@ -186,6 +201,11 @@ def __init__( else: self.register_module("encoder_decoder_attn_dropout", None) + if encoder_decoder_attn_residual is None: + encoder_decoder_attn_residual = StandardResidualConnect() + + self.encoder_decoder_attn_residual = encoder_decoder_attn_residual + if norm_order == TransformerNormOrder.POST: self.encoder_decoder_attn_layer_norm = encoder_decoder_attn_layer_norm @@ -201,36 +221,27 @@ def __init__( else: self.register_module("ffn_dropout", None) - if scale_residual: - self.residual_scale = Parameter( - torch.empty((model_dim,), device=device, dtype=dtype) - ) - else: - self.register_parameter("residual_scale", None) + if ffn_residual is None: + ffn_residual = StandardResidualConnect() + + self.ffn_residual = ffn_residual if norm_order == TransformerNormOrder.POST: self.ffn_layer_norm = ffn_layer_norm self.norm_order = norm_order - self.reset_parameters() - - def reset_parameters(self) -> None: - """Reset the parameters and buffers of the module.""" - if self.residual_scale is not None: - nn.init.ones_(self.residual_scale) - @override def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask] = None, - encoder_output: Optional[Tensor] = None, - encoder_padding_mask: Optional[PaddingMask] = None, + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + encoder_output: Tensor | None = None, + encoder_padding_mask: PaddingMask | None = None, *, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, PaddingMask | None]: seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask, state_bag) seqs = self._forward_encoder_decoder_attn( @@ -244,9 +255,9 @@ def forward( def _forward_self_attn( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -269,7 +280,7 @@ def _forward_self_attn( if self.self_attn_dropout is not None: seqs = self.self_attn_dropout(seqs) - seqs = seqs + residual + seqs = self.self_attn_residual(seqs, residual) if self.norm_order == TransformerNormOrder.POST: seqs = self.self_attn_layer_norm(seqs) @@ -279,10 +290,10 @@ def _forward_self_attn( def _forward_encoder_decoder_attn( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - encoder_output: Optional[Tensor], - encoder_padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag], + padding_mask: PaddingMask | None, + encoder_output: Tensor | None, + encoder_padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None, ) -> Tensor: if self.encoder_decoder_attn is None: if encoder_output is not None: @@ -297,10 +308,13 @@ def _forward_encoder_decoder_attn( "`encoder_output` must not be `None` for encoder-decoder attention." ) + assert self.encoder_decoder_attn_residual is not None + assert self.encoder_decoder_attn_layer_norm is not None + residual = seqs if self.norm_order != TransformerNormOrder.POST: - seqs = cast(LayerNorm, self.encoder_decoder_attn_layer_norm)(seqs) + seqs = self.encoder_decoder_attn_layer_norm(seqs) seqs = self.encoder_decoder_attn( seqs, @@ -314,10 +328,10 @@ def _forward_encoder_decoder_attn( if self.encoder_decoder_attn_dropout is not None: seqs = self.encoder_decoder_attn_dropout(seqs) - seqs = seqs + residual + seqs = self.encoder_decoder_attn_residual(seqs, residual) if self.norm_order == TransformerNormOrder.POST: - seqs = cast(LayerNorm, self.encoder_decoder_attn_layer_norm)(seqs) + seqs = self.encoder_decoder_attn_layer_norm(seqs) return seqs @@ -332,10 +346,7 @@ def _forward_ffn(self, seqs: Tensor) -> Tensor: if self.ffn_dropout is not None: seqs = self.ffn_dropout(seqs) - if self.residual_scale is not None: - residual = self.residual_scale * residual - - seqs = seqs + residual + seqs = self.ffn_residual(seqs, residual) if self.norm_order == TransformerNormOrder.POST: seqs = self.ffn_layer_norm(seqs) diff --git a/src/fairseq2/nn/transformer/encoder.py b/src/fairseq2/nn/transformer/encoder.py index e76e961e3..0827c41e6 100644 --- a/src/fairseq2/nn/transformer/encoder.py +++ b/src/fairseq2/nn/transformer/encoder.py @@ -8,24 +8,27 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Dict, Iterable, Iterator, Optional, Protocol, Tuple, final +from collections.abc import Iterable, Iterator +from typing import Any, Protocol, final import torch from torch import Generator, Tensor from torch.autograd import Function from torch.nn import Dropout, Module, ModuleList from torch.utils.hooks import RemovableHandle +from typing_extensions import override +from fairseq2.error import InvalidOperationError from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.padding import PaddingMask from fairseq2.nn.transformer.attention_mask import AttentionMaskFactory from fairseq2.nn.transformer.encoder_layer import TransformerEncoderLayer from fairseq2.nn.transformer.layer_norm import ( LayerNormFactory, - create_standard_layer_norm, + make_standard_layer_norm, ) from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import CPU, DataType, Device, override +from fairseq2.typing import CPU, DataType, Device class TransformerEncoder(Module, ABC): @@ -34,7 +37,7 @@ class TransformerEncoder(Module, ABC): model_dim: int layers: ModuleList - _layer_output_hooks: Dict[int, EncoderLayerOutputHook] + _layer_output_hooks: dict[int, EncoderLayerOutputHook] def __init__(self, model_dim: int) -> None: """ @@ -49,8 +52,8 @@ def __init__(self, model_dim: int) -> None: @abstractmethod def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: """ :param seqs: The sequences to encode. *Shape:* :math:`(N,S,M)`, where :math:`N` @@ -100,7 +103,7 @@ def __call__( self, layer_idx: int, layer_output: Tensor, - layer_padding_mask: Optional[PaddingMask], + layer_padding_mask: PaddingMask | None, num_layers: int, ) -> bool: """ @@ -125,10 +128,10 @@ class StandardTransformerEncoder(TransformerEncoder): """Represents a Transformer encoder as described in :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.""" - self_attn_mask_factory: Optional[AttentionMaskFactory] + self_attn_mask_factory: AttentionMaskFactory | None layer_drop_p: float - generator: Optional[Generator] - layer_norm: Optional[LayerNorm] + generator: Generator | None + layer_norm: LayerNorm | None dropout_p: float norm_order: TransformerNormOrder @@ -136,14 +139,14 @@ def __init__( self, layers: Iterable[TransformerEncoderLayer], *, - self_attn_mask_factory: Optional[AttentionMaskFactory] = None, + self_attn_mask_factory: AttentionMaskFactory | None = None, layer_drop_p: float = 0.0, - generator: Optional[Generator] = None, + generator: Generator | None = None, dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param layers: @@ -171,7 +174,7 @@ def __init__( super().__init__(model_dim) if layer_norm_factory is None: - layer_norm_factory = create_standard_layer_norm + layer_norm_factory = make_standard_layer_norm self.self_attn_mask_factory = self_attn_mask_factory @@ -195,10 +198,10 @@ def __init__( @override def forward( - self, seqs: Tensor, padding_mask: Optional[PaddingMask] - ) -> Tuple[Tensor, Optional[PaddingMask]]: + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: if self._layer_output_hooks and self.layer_drop_p > 0.0 and self.training: - raise RuntimeError( + raise InvalidOperationError( "The layer output hooks cannot be run when LayerDrop is enabled." ) @@ -233,7 +236,7 @@ def forward( return seqs, padding_mask - def _drop_iter(self) -> Iterator[Tuple[Module, bool]]: + def _drop_iter(self) -> Iterator[tuple[Module, bool]]: if self.training and self.layer_drop_p > 0.0: prob_dist = torch.rand( len(self.layers), generator=self.generator, device=CPU @@ -276,5 +279,5 @@ def forward(ctx: Any, x: Tensor, dropped_output: Tensor) -> Tensor: return x @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor]: + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, Tensor]: return grad_output, torch.zeros_like(grad_output) diff --git a/src/fairseq2/nn/transformer/encoder_layer.py b/src/fairseq2/nn/transformer/encoder_layer.py index d454ac5ad..580af8748 100644 --- a/src/fairseq2/nn/transformer/encoder_layer.py +++ b/src/fairseq2/nn/transformer/encoder_layer.py @@ -7,13 +7,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Tuple, final +from typing import final -import torch -import torch.nn as nn from torch import Tensor from torch.nn import Dropout, Module -from torch.nn.parameter import Parameter +from typing_extensions import override from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.padding import PaddingMask @@ -21,11 +19,12 @@ from fairseq2.nn.transformer.ffn import FeedForwardNetwork from fairseq2.nn.transformer.layer_norm import ( LayerNormFactory, - create_standard_layer_norm, + make_standard_layer_norm, ) from fairseq2.nn.transformer.multihead_attention import MultiheadAttention from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import DataType, Device, override +from fairseq2.nn.transformer.residual import ResidualConnect, StandardResidualConnect +from fairseq2.typing import DataType, Device class TransformerEncoderLayer(Module, ABC): @@ -46,9 +45,9 @@ def __init__(self, model_dim: int) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + ) -> tuple[Tensor, PaddingMask | None]: """ :param seqs: The sequences to process. *Shape:* :math:`(N,S,M)`, where :math:`N` @@ -80,12 +79,13 @@ class StandardTransformerEncoderLayer(TransformerEncoderLayer): """ self_attn: MultiheadAttention - self_attn_norm: Optional[LayerNorm] - self_attn_dropout: Optional[Dropout] + self_attn_norm: LayerNorm | None + self_attn_dropout: Dropout | None + self_attn_residual: ResidualConnect self_attn_layer_norm: LayerNorm ffn: FeedForwardNetwork - ffn_dropout: Optional[Dropout] - residual_scale: Optional[Parameter] + ffn_dropout: Dropout | None + ffn_residual: ResidualConnect ffn_layer_norm: LayerNorm norm_order: TransformerNormOrder @@ -94,22 +94,19 @@ def __init__( self_attn: MultiheadAttention, ffn: FeedForwardNetwork, *, - scale_residual: bool = False, dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + self_attn_residual: ResidualConnect | None = None, + ffn_residual: ResidualConnect | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param self_attn: The self attention layer. :param ffn: The feed-forward network. - :param scale_residual: - If ``True``, scales residuals before adding them to the output of - the feed-forward network as described in - :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`. :param dropout_p: The dropout probability on outputs of the self attention layer and the feed-forward network. @@ -117,13 +114,19 @@ def __init__( The Layer Normalization order. :param layer_norm_factory: The factory to construct the Layer Normalization modules. + :param self_attn_residual: + The residual connection between the input and output of the self + attention layer. + :param ffn_residual: + The residual connection between the input and output of the + feed-forward network. """ model_dim = self_attn.model_dim super().__init__(model_dim) if layer_norm_factory is None: - layer_norm_factory = create_standard_layer_norm + layer_norm_factory = make_standard_layer_norm self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) @@ -144,6 +147,11 @@ def __init__( else: self.register_module("self_attn_dropout", None) + if self_attn_residual is None: + self_attn_residual = StandardResidualConnect() + + self.self_attn_residual = self_attn_residual + if norm_order == TransformerNormOrder.POST: self.self_attn_layer_norm = self_attn_layer_norm @@ -159,32 +167,23 @@ def __init__( else: self.register_module("ffn_dropout", None) - if scale_residual: - self.residual_scale = Parameter( - torch.empty((model_dim,), device=device, dtype=dtype) - ) - else: - self.register_parameter("residual_scale", None) + if ffn_residual is None: + ffn_residual = StandardResidualConnect() + + self.ffn_residual = ffn_residual if norm_order == TransformerNormOrder.POST: self.ffn_layer_norm = ffn_layer_norm self.norm_order = norm_order - self.reset_parameters() - - def reset_parameters(self) -> None: - """Reset the parameters and buffers of the module.""" - if self.residual_scale is not None: - nn.init.ones_(self.residual_scale) - @override def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask] = None, - ) -> Tuple[Tensor, Optional[PaddingMask]]: + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + ) -> tuple[Tensor, PaddingMask | None]: seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask) seqs = self._forward_ffn(seqs) @@ -194,8 +193,8 @@ def forward( def _forward_self_attn( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - self_attn_mask: Optional[AttentionMask], + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None, ) -> Tensor: residual = seqs @@ -217,7 +216,7 @@ def _forward_self_attn( if self.self_attn_dropout is not None: seqs = self.self_attn_dropout(seqs) - seqs = seqs + residual + seqs = self.self_attn_residual(seqs, residual) if self.norm_order == TransformerNormOrder.POST: seqs = self.self_attn_layer_norm(seqs) @@ -235,10 +234,7 @@ def _forward_ffn(self, seqs: Tensor) -> Tensor: if self.ffn_dropout is not None: seqs = self.ffn_dropout(seqs) - if self.residual_scale is not None: - residual = self.residual_scale * residual - - seqs = seqs + residual + seqs = self.ffn_residual(seqs, residual) if self.norm_order == TransformerNormOrder.POST: seqs = self.ffn_layer_norm(seqs) diff --git a/src/fairseq2/nn/transformer/ffn.py b/src/fairseq2/nn/transformer/ffn.py index 732922a5c..2c04d70f3 100644 --- a/src/fairseq2/nn/transformer/ffn.py +++ b/src/fairseq2/nn/transformer/ffn.py @@ -7,19 +7,21 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, final +from collections.abc import Callable +from typing import final from torch import Tensor -from torch.nn import Dropout, Module, ReLU, SiLU +from torch.nn import Dropout, Module, ReLU, Sigmoid, SiLU +from typing_extensions import override from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.projection import Linear, Projection from fairseq2.nn.transformer.layer_norm import ( LayerNormFactory, - create_standard_layer_norm, + make_standard_layer_norm, ) from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device class FeedForwardNetwork(Module, ABC): @@ -60,8 +62,8 @@ class StandardFeedForwardNetwork(FeedForwardNetwork): inner_proj: Projection inner_activation: Module - inner_dropout: Optional[Dropout] - inner_norm: Optional[LayerNorm] + inner_dropout: Dropout | None + inner_norm: LayerNorm | None output_proj: Projection def __init__( @@ -70,12 +72,13 @@ def __init__( inner_dim: int, bias: bool, *, - inner_activation: Optional[Module] = None, + inner_activation: Module | None = None, inner_dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, - layer_norm_factory: Optional[LayerNormFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + layer_norm_factory: LayerNormFactory | None = None, + proj_init_fn: Callable[[Linear], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -94,13 +97,17 @@ def __init__( The Layer Normalization order. :param layer_norm_factory: The factory to construct the Layer Normalization module. + :param proj_init_fn: + The callable to initialize the inner and output projections. """ super().__init__(model_dim) if layer_norm_factory is None: - layer_norm_factory = create_standard_layer_norm + layer_norm_factory = make_standard_layer_norm - self.inner_proj = Linear(model_dim, inner_dim, bias, device=device, dtype=dtype) + self.inner_proj = Linear( + model_dim, inner_dim, bias, init_fn=proj_init_fn, device=device, dtype=dtype + ) if inner_activation is None: self.inner_activation = ReLU() @@ -120,7 +127,7 @@ def __init__( self.register_module("inner_layer_norm", None) self.output_proj = Linear( - inner_dim, model_dim, bias, device=device, dtype=dtype + inner_dim, model_dim, bias, init_fn=proj_init_fn, device=device, dtype=dtype ) @override @@ -140,6 +147,85 @@ def forward(self, seqs: Tensor) -> Tensor: return seqs +@final +class DauphinFeedForwardNetwork(FeedForwardNetwork): + """Represents a GLU-based Transformer feed-forward network as described in + :cite:t:`https://doi.org/10.48550/arXiv.1612.08083`""" + + inner_proj: Projection + inner_activation: Module + inner_dropout: Dropout | None + output_proj: Projection + + def __init__( + self, + model_dim: int, + inner_dim: int, + bias: bool, + *, + inner_activation: Module | None = None, + inner_dropout_p: float = 0.0, + proj_init_fn: Callable[[Linear], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param model_dim: + The dimensionality of the model. + :param inner_dim: + The dimensionality of the inner projection layer. + :param bias: + If ``True``, both the inner and output projection learn an additive + bias. + :param inner_activation: + The activation to apply to outputs of the inner projection layer. If + ``None``, :func:`~torch.nn.Sigmoid` will be used. + :param inner_dropout_p: + The dropout probability on outputs of the inner projection layer. + :param proj_init_fn: + The callable to initialize the inner and output projections. + """ + super().__init__(model_dim) + + self.inner_proj = Linear( + model_dim, + inner_dim * 2, + bias, + init_fn=proj_init_fn, + device=device, + dtype=dtype, + ) + + if inner_activation is None: + self.inner_activation = Sigmoid() + else: + self.inner_activation = inner_activation + + if inner_dropout_p > 0.0: + self.inner_dropout = Dropout(inner_dropout_p) + else: + self.register_module("inner_dropout", None) + + self.output_proj = Linear( + inner_dim, model_dim, bias, device=device, dtype=dtype + ) + + @override + def forward(self, seqs: Tensor) -> Tensor: + seqs = self.inner_proj(seqs) + + split1, split2 = seqs.chunk(2, dim=-1) + + seqs = self.inner_activation(split1) * split2 # gate + + if self.inner_dropout is not None: + seqs = self.inner_dropout(seqs) + + seqs = self.output_proj(seqs) + + return seqs + + @final class GLUFeedForwardNetwork(FeedForwardNetwork): """Represents a GLU-based Transformer feed-forward network as described in @@ -150,7 +236,7 @@ class GLUFeedForwardNetwork(FeedForwardNetwork): inner_dim_scale: float inner_dim_to_multiple: int inner_proj: Projection - inner_dropout: Optional[Dropout] + inner_dropout: Dropout | None output_proj: Projection def __init__( @@ -159,12 +245,12 @@ def __init__( inner_dim: int, bias: bool, *, - gate_activation: Optional[Module] = None, + gate_activation: Module | None = None, inner_dim_scale: float = 2 / 3, inner_dim_to_multiple: int = 1, inner_dropout_p: float = 0.0, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: diff --git a/src/fairseq2/nn/transformer/layer_norm.py b/src/fairseq2/nn/transformer/layer_norm.py index 0210ebebb..794b1d321 100644 --- a/src/fairseq2/nn/transformer/layer_norm.py +++ b/src/fairseq2/nn/transformer/layer_norm.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Optional, Protocol +from typing import Protocol from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm from fairseq2.typing import DataType, Device @@ -19,8 +19,8 @@ def __call__( self, model_dim: int, *, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> LayerNorm: """ :param model_dim: @@ -32,8 +32,11 @@ def __call__( """ -def create_standard_layer_norm( - model_dim: int, *, device: Optional[Device] = None, dtype: Optional[DataType] = None +def make_standard_layer_norm( + model_dim: int, *, device: Device | None = None, dtype: DataType | None = None ) -> LayerNorm: - """Create a :class:`StandardLayerNorm` instance.""" + """Make a :class:`StandardLayerNorm` instance.""" return StandardLayerNorm(model_dim, bias=True, device=device, dtype=dtype) + + +create_standard_layer_norm = make_standard_layer_norm # compat diff --git a/src/fairseq2/nn/transformer/multihead_attention.py b/src/fairseq2/nn/transformer/multihead_attention.py index c6841bc87..9d31c5b60 100644 --- a/src/fairseq2/nn/transformer/multihead_attention.py +++ b/src/fairseq2/nn/transformer/multihead_attention.py @@ -8,7 +8,8 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Dict, MutableSequence, Optional, Protocol, Tuple, final +from collections.abc import Callable, MutableSequence +from typing import Protocol, final import torch import torch.nn as nn @@ -16,15 +17,17 @@ from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle +from typing_extensions import override +from fairseq2.error import NotSupportedError from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag from fairseq2.nn.ops import repeat_interleave from fairseq2.nn.padding import PaddingMask from fairseq2.nn.position_encoder import PositionEncoder from fairseq2.nn.projection import Linear, Projection -from fairseq2.nn.transformer.attention import SDPA, create_default_sdpa +from fairseq2.nn.transformer.attention import SDPA, make_default_sdpa from fairseq2.nn.transformer.attention_mask import AttentionMask, AttentionMaskFactory -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device class MultiheadAttention(Module, ABC): @@ -33,7 +36,7 @@ class MultiheadAttention(Module, ABC): num_heads: int model_dim: int - _attn_weight_hooks: Dict[int, AttentionWeightHook] + _attn_weight_hooks: dict[int, AttentionWeightHook] def __init__(self, model_dim: int, num_heads: int) -> None: """ @@ -53,13 +56,13 @@ def __init__(self, model_dim: int, num_heads: int) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, - state_bag: Optional[IncrementalStateBag] = None, + attn_mask: AttentionMask | None = None, + state_bag: IncrementalStateBag | None = None, ) -> Tensor: """ :param seqs: @@ -149,9 +152,9 @@ class AttentionWeightStoreHook(AttentionWeightHook): This class follows the :class:`AttentionWeightHook` protocol. """ - _storage: MutableSequence[Tuple[Tensor, Tensor]] + _storage: MutableSequence[tuple[Tensor, Tensor]] - def __init__(self, storage: MutableSequence[Tuple[Tensor, Tensor]]) -> None: + def __init__(self, storage: MutableSequence[tuple[Tensor, Tensor]]) -> None: """ :param storage: The storage in which to store attention weights. @@ -174,35 +177,35 @@ class StandardMultiheadAttention(MultiheadAttention): q_proj: Projection k_proj: Projection v_proj: Projection - attn_mask_factory: Optional[AttentionMaskFactory] - pos_encoder: Optional[PositionEncoder] - bias_k: Optional[Parameter] - bias_v: Optional[Parameter] - add_zero_attn: bool + attn_mask_factory: AttentionMaskFactory | None + pos_encoder: PositionEncoder | None sdpa: SDPA - head_scale_weight: Optional[Parameter] + head_scale_weight: Parameter | None output_proj: Projection - state_factory: Optional[AttentionStateFactory] + state_factory: AttentionStateFactory | None def __init__( self, model_dim: int, num_heads: int, *, - kv_dim: Optional[int] = None, - num_key_value_heads: Optional[int] = None, - q_proj: Optional[Projection] = None, - k_proj: Optional[Projection] = None, - v_proj: Optional[Projection] = None, - attn_mask_factory: Optional[AttentionMaskFactory] = None, - pos_encoder: Optional[PositionEncoder] = None, - sdpa: Optional[SDPA] = None, + kv_dim: int | None = None, + num_key_value_heads: int | None = None, + q_proj: Projection | None = None, + k_proj: Projection | None = None, + v_proj: Projection | None = None, + qkv_proj_init_fn: Callable[[Linear], None] | None = None, + attn_mask_factory: AttentionMaskFactory | None = None, + pos_encoder: PositionEncoder | None = None, + sdpa: SDPA | None = None, scale_heads: bool = False, - output_proj: Optional[Projection] = None, + output_proj: Projection | None = None, + output_proj_init_fn: Callable[[Linear], None] | None = None, bias: bool = True, - state_factory: Optional[AttentionStateFactory] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + output_proj_bias: bool | None = None, + state_factory: AttentionStateFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -227,6 +230,8 @@ def __init__( :param v_proj: The projection to apply to values before computing attention. If ``None``, a default projection will be used. + :param qkv_proj_init_fn: + The callable to initialize the q, k, v projections. :param attn_mask_factory: The attention mask factory. :param pos_encoder: @@ -240,9 +245,15 @@ def __init__( :param output_proj: The projection to produce final attentions. If ``None``, a default projection will be used. + :param output_proj_init_fn: + The callable to initialize the output projection. :param bias: - If ``True``, query, key, value, and output projections learn an - additive bias. Ignored for explicitly specified projections. + If ``True``, query, key, and value projections learn an additive + bias. Ignored for explicitly specified projections. + :param output_proj_bias: + If ``True``, output projection learns an additive bias. If ``None``, + the value of ``bias`` is used. Ignored for explicitly specified + projections. :param state_factory: The factory to construct :class:`AttentionState` instances for incremental decoding. @@ -275,7 +286,7 @@ def __init__( model_dim, model_dim, bias, - init_fn=init_qkv_projection, + init_fn=qkv_proj_init_fn or init_qkv_projection, device=device, dtype=dtype, ) @@ -283,7 +294,7 @@ def __init__( self.kv_dim, head_dim * self.num_key_value_heads, bias, - init_fn=init_qkv_projection, + init_fn=qkv_proj_init_fn or init_qkv_projection, device=device, dtype=dtype, ) @@ -291,14 +302,17 @@ def __init__( self.kv_dim, head_dim * self.num_key_value_heads, bias, - init_fn=init_qkv_projection, + init_fn=qkv_proj_init_fn or init_qkv_projection, device=device, dtype=dtype, ) else: if q_proj is None or k_proj is None or v_proj is None: + raise ValueError("`q_proj`, `k_proj`, `v_proj` must be all specified.") + + if qkv_proj_init_fn is not None: raise ValueError( - "`q_proj`, `k_proj`, and `v_proj` must be all specified." + "`qkv_proj_init_fn` must be `None`, when `q_proj`, `k_proj`, `v_proj` are specified." ) if q_proj.input_dim != self.kv_dim: @@ -340,7 +354,7 @@ def __init__( if sdpa is not None: self.sdpa = sdpa else: - self.sdpa = create_default_sdpa() + self.sdpa = make_default_sdpa() if scale_heads: self.head_scale_weight = Parameter( @@ -352,15 +366,23 @@ def __init__( v_dim = v_proj.output_dim * num_query_groups if output_proj is None: + if output_proj_bias is None: + output_proj_bias = bias + self.output_proj = Linear( v_dim, model_dim, - bias, - init_fn=init_output_projection, + output_proj_bias, + init_fn=output_proj_init_fn or init_output_projection, device=device, dtype=dtype, ) else: + if output_proj_init_fn is not None: + raise ValueError( + "`output_proj_init_fn` must be `None`, when `output_proj` is specified." + ) + if v_dim != output_proj.input_dim: raise ValueError( f"`output_dim` of `v_proj` (times the number of query groups when GQA) and `input_dim` of `output_proj` must be equal, but are {v_dim} and {output_proj.input_dim} instead." @@ -386,13 +408,13 @@ def reset_parameters(self) -> None: def forward( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], + padding_mask: PaddingMask | None, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, - state_bag: Optional[IncrementalStateBag] = None, + attn_mask: AttentionMask | None = None, + state_bag: IncrementalStateBag | None = None, ) -> Tensor: # (N, S, M) -> (N, H, S, K_h) q = self._project_q(seqs, padding_mask, state_bag) @@ -492,8 +514,8 @@ def forward( def _project_q( self, seqs: Tensor, - padding_mask: Optional[PaddingMask], - state_bag: Optional[IncrementalStateBag] = None, + padding_mask: PaddingMask | None, + state_bag: IncrementalStateBag | None = None, ) -> Tensor: # (N, S, M) -> (N, S, K_proj) q = self.q_proj(seqs) @@ -509,10 +531,10 @@ def _project_q( def _project_kv( self, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, - state_bag: Optional[IncrementalStateBag] = None, - ) -> Tuple[Tensor, Tensor]: + state_bag: IncrementalStateBag | None = None, + ) -> tuple[Tensor, Tensor]: # (N, S, K) -> (N, S, K_proj) k = self.k_proj(keys) # (N, S, V) -> (N, S, V_proj) @@ -582,7 +604,7 @@ def append(self, k: Tensor, v: Tensor) -> None: """ @abstractmethod - def get(self) -> Tuple[Tensor, Tensor]: + def get(self) -> tuple[Tensor, Tensor]: """Return the state that should be used to compute the attention. :returns: @@ -595,7 +617,7 @@ class AttentionStateFactory(Protocol): """Constructs instances of :class:`AttentionState`.""" def __call__( - self, k: Tensor, v: Tensor, max_seq_len: int, capacity_increment: Optional[int] + self, k: Tensor, v: Tensor, max_seq_len: int, capacity_increment: int | None ) -> AttentionState: """ :param k: @@ -631,12 +653,12 @@ class FullAttentionState(AttentionState): is the number of heads, :math:`S_{rsv}` is the reserved sequence length capacity, and :math:`V_{proj}` is the projected value size.""" - _capacity_increment: Optional[int] + _capacity_increment: int | None """The sequence length capacity of :attr:`k` and :attr:`v` is incremented by multiples of this value.""" def __init__( - self, k: Tensor, v: Tensor, max_seq_len: int, capacity_increment: Optional[int] + self, k: Tensor, v: Tensor, max_seq_len: int, capacity_increment: int | None ) -> None: if capacity_increment is not None and capacity_increment < 1: raise ValueError( @@ -700,7 +722,7 @@ def _expand_kv(self, input_seq_len: int) -> None: self._v = v @override - def get(self) -> Tuple[Tensor, Tensor]: + def get(self) -> tuple[Tensor, Tensor]: k = self._k[:, :, : self._seq_len] v = self._v[:, :, : self._seq_len] @@ -753,7 +775,7 @@ class LocalAttentionState(AttentionState): reserved sequence length capacity, and :math:`V_{proj}` is the projected value size.""" - _capacity_increment: Optional[int] + _capacity_increment: int | None """The sequence length capacity of :attr:`k` and :attr:`v` is incremented by multiples of this value.""" @@ -763,7 +785,7 @@ def __init__( v: Tensor, max_seq_len: int, attn_window_len: int, - capacity_increment: Optional[int], + capacity_increment: int | None, ) -> None: if capacity_increment is not None and capacity_increment < 1: raise ValueError( @@ -844,7 +866,7 @@ def _expand_kv(self, input_seq_len: int) -> None: self._v = v @override - def get(self) -> Tuple[Tensor, Tensor]: + def get(self) -> tuple[Tensor, Tensor]: k = self._k[:, :, : self._seq_len] v = self._v[:, :, : self._seq_len] @@ -889,7 +911,7 @@ def __call__( k: Tensor, v: Tensor, max_seq_len: int, - capacity_increment: Optional[int], + capacity_increment: int | None, ) -> LocalAttentionState: return LocalAttentionState( k, v, max_seq_len, self._attn_window_len, capacity_increment @@ -908,17 +930,17 @@ class StaticAttentionState(AttentionState): _v: Tensor def __init__( - self, k: Tensor, v: Tensor, max_seq_len: int, capacity_increment: Optional[int] + self, k: Tensor, v: Tensor, max_seq_len: int, capacity_increment: int | None ) -> None: self._k = k self._v = v @override def append(self, k: Tensor, v: Tensor) -> None: - raise ValueError(" `StaticAttentionState` does not support `append()`.") + raise NotSupportedError(f"`{type(self)}` does not support `append()`.") @override - def get(self) -> Tuple[Tensor, Tensor]: + def get(self) -> tuple[Tensor, Tensor]: return self._k, self._v @override diff --git a/src/fairseq2/nn/transformer/relative_attention.py b/src/fairseq2/nn/transformer/relative_attention.py index 0e07ce33b..c9853fbfc 100644 --- a/src/fairseq2/nn/transformer/relative_attention.py +++ b/src/fairseq2/nn/transformer/relative_attention.py @@ -7,19 +7,20 @@ from __future__ import annotations import math -from typing import Optional, Tuple, final +from typing import final import torch import torch.nn as nn from torch import Tensor from torch.nn import Module, Parameter from torch.nn.functional import pad +from typing_extensions import override from fairseq2.nn.padding import PaddingMask from fairseq2.nn.projection import Linear -from fairseq2.nn.transformer.attention import SDPA, create_default_sdpa +from fairseq2.nn.transformer.attention import SDPA, make_default_sdpa from fairseq2.nn.transformer.attention_mask import AttentionMask, CustomAttentionMask -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -41,9 +42,9 @@ def __init__( num_heads: int, pos_encoding: RelativePositionalEncoding, *, - inner_sdpa: Optional[SDPA] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + inner_sdpa: SDPA | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -88,7 +89,7 @@ def __init__( if inner_sdpa is not None: self.inner_sdpa = inner_sdpa else: - self.inner_sdpa = create_default_sdpa() + self.inner_sdpa = make_default_sdpa() self.reset_parameters() @@ -102,12 +103,12 @@ def forward( self, seqs: Tensor, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, + attn_mask: AttentionMask | None = None, needs_weights: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: q = seqs k = keys @@ -203,7 +204,7 @@ def __init__( encoding_dim: int, max_seq_len: int, *, - device: Optional[Device] = None, + device: Device | None = None, ) -> None: """ :param encoding_dim: diff --git a/src/fairseq2/nn/transformer/residual.py b/src/fairseq2/nn/transformer/residual.py new file mode 100644 index 000000000..a510a3a10 --- /dev/null +++ b/src/fairseq2/nn/transformer/residual.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import final + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Module, Parameter +from typing_extensions import override + +from fairseq2.typing import DataType, Device + + +class ResidualConnect(Module, ABC): + """Represents a residual connection within a Transformer layer.""" + + @abstractmethod + def forward(self, seqs: Tensor, residual: Tensor) -> Tensor: + """ + :param seqs: The sequences output by a module such as a multi-head + attention layer or a feed-forward network. *Shape:* :math:`(N,S,M)`, + where :math:`N` is the batch size, :math:`S` is the sequence length, + and :math:`M` is the dimensionality of the model. + :param residual: The input sequences to the module. *Shape:* Same as + ``seqs``. + + :returns: The output sequences with residuals applied. *Shape:* Same as + ``seqs``. + """ + + +@final +class StandardResidualConnect(ResidualConnect): + """Sums inputs and outputs of a Transformer module.""" + + @override + def forward(self, seqs: Tensor, residual: Tensor) -> Tensor: + return seqs + residual + + +@final +class ScaledResidualConnect(ResidualConnect): + """ + Scales residuals by a constant factor before adding them to the output of a + Transformer module. + """ + + scale: float + + def __init__(self, scale: float) -> None: + """ + :param scale: The scale factor. + """ + self.scale = scale + + @override + def forward(self, seqs: Tensor, residual: Tensor) -> Tensor: + residual = self.scale * residual + + return seqs + residual + + +@final +class NormFormerResidualConnect(ResidualConnect): + """ + Scales residuals by a learned factor before adding them to the output of a + feed-forward network as described in + :cite:t:`https://doi.org/10.48550/arxiv.2110.09456`. + """ + + scale_proj: Parameter + + def __init__( + self, + model_dim: int, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param model_dim: The dimensionality of the model. + """ + super().__init__() + + self.scale_proj = Parameter( + torch.empty((model_dim,), device=device, dtype=dtype) + ) + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + nn.init.ones_(self.scale_proj) + + @override + def forward(self, seqs: Tensor, residual: Tensor) -> Tensor: + residual = self.scale_proj * residual + + return seqs + residual + + +@final +class DropPathResidualConnect(ResidualConnect): + """ + Drops entire sequences from module outputs before adding residuals which + effectively results in stochastic depth as described in section 3 of + :cite:t:`https://doi.org/10.48550/arxiv.1603.09382`. + + .. note:: + This implementation is mostly adapted from Ross Wightman's ``drop_path`` + function in timm. + """ + + drop_p: float + scale_by_keep: bool + + def __init__(self, drop_p: float, scale_by_keep: bool = True) -> None: + """ + :param drop_p: The probability of dropping sequences from module outputs. + :param scale_by_keep: If ``True``, non-dropped sequences will be scaled + by the keep probability (i.e. ``1 - drop_p``) as in EfficientNet. + """ + super().__init__() + + self.drop_p = drop_p + self.scale_by_keep = scale_by_keep + + @override + def forward(self, seqs: Tensor, residual: Tensor) -> Tensor: + if not self.training or self.drop_p == 0.0: + return seqs + residual + + shape = [seqs.size(0)] + [1] * (seqs.ndim - 1) + + keep_p = 1.0 - self.drop_p + + # (N) + drop_mask = torch.rand(shape, device=seqs.device, dtype=seqs.dtype) + keep_p + + drop_mask.floor_() # binarize + + if self.scale_by_keep: + seqs = seqs / keep_p + + return (seqs * drop_mask) + residual + + def extra_repr(self) -> str: + return f"drop_p={self.drop_p}, scale_by_keep={self.scale_by_keep}" diff --git a/src/fairseq2/nn/transformer/shaw_attention.py b/src/fairseq2/nn/transformer/shaw_attention.py index 611ad6910..6761257d1 100644 --- a/src/fairseq2/nn/transformer/shaw_attention.py +++ b/src/fairseq2/nn/transformer/shaw_attention.py @@ -6,17 +6,19 @@ from __future__ import annotations -from typing import Optional, Tuple, final +from typing import final import torch import torch.nn as nn from torch import Tensor +from typing_extensions import override +from fairseq2.error import InternalError from fairseq2.nn.embedding import StandardEmbedding from fairseq2.nn.padding import PaddingMask -from fairseq2.nn.transformer.attention import SDPA, create_default_sdpa +from fairseq2.nn.transformer.attention import SDPA, make_default_sdpa from fairseq2.nn.transformer.attention_mask import AttentionMask, CustomAttentionMask -from fairseq2.typing import DataType, Device, override +from fairseq2.typing import DataType, Device @final @@ -29,7 +31,7 @@ class ShawRelativePositionSDPA(SDPA): max_left_rel_pos: int max_right_rel_pos: int rel_k_embed: StandardEmbedding - rel_v_embed: Optional[StandardEmbedding] + rel_v_embed: StandardEmbedding | None inner_sdpa: SDPA def __init__( @@ -38,11 +40,11 @@ def __init__( num_heads: int, max_left_rel_pos: int, *, - max_right_rel_pos: Optional[int] = None, + max_right_rel_pos: int | None = None, use_rel_pos_values: bool = False, - inner_sdpa: Optional[SDPA] = None, - device: Optional[Device] = None, - dtype: Optional[DataType] = None, + inner_sdpa: SDPA | None = None, + device: Device | None = None, + dtype: DataType | None = None, ) -> None: """ :param model_dim: @@ -96,19 +98,19 @@ def __init__( if inner_sdpa is not None: self.inner_sdpa = inner_sdpa else: - self.inner_sdpa = create_default_sdpa() + self.inner_sdpa = make_default_sdpa() @override def forward( self, seqs: Tensor, keys: Tensor, - key_padding_mask: Optional[PaddingMask], + key_padding_mask: PaddingMask | None, values: Tensor, *, - attn_mask: Optional[AttentionMask] = None, + attn_mask: AttentionMask | None = None, needs_weights: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: + ) -> tuple[Tensor, Tensor | None]: q_len = seqs.size(2) # (S_kv, S_kv) @@ -144,7 +146,8 @@ def forward( ) if self.rel_v_embed is not None: - assert attn_weights is not None + if attn_weights is None: + raise InternalError("`attn_weights` is `None`.") # (S_kv, S_kv, V_h) rel_pos_values = self.rel_v_embed(rel_indices) diff --git a/src/fairseq2/nn/utils/gradient.py b/src/fairseq2/nn/utils/gradient.py index d188c6b77..cc4eb7c64 100644 --- a/src/fairseq2/nn/utils/gradient.py +++ b/src/fairseq2/nn/utils/gradient.py @@ -7,7 +7,7 @@ from __future__ import annotations import logging -from typing import Any, Optional, Tuple, Union +from typing import Any import torch from torch import Tensor @@ -17,9 +17,7 @@ from torch.nn.utils import clip_grad_norm_ # type: ignore[attr-defined] from fairseq2.gang import Gang, all_sum -from fairseq2.logging import get_log_writer - -log = get_log_writer(__name__) +from fairseq2.logging import log def normalize_gradients(module: Module, gang: Gang, num_targets: int) -> None: @@ -38,7 +36,7 @@ def normalize_gradients(module: Module, gang: Gang, num_targets: int) -> None: scale_gradients(module, gang.size / total_num_targets) -def scale_gradients(module: Module, value: Union[float, Tensor]) -> None: +def scale_gradients(module: Module, value: float | Tensor) -> None: """Scale gradients of ``module`` by ``value``. :param module: @@ -70,20 +68,20 @@ class _GradientScaleFunction(Function): def forward(ctx: Any, x: Tensor, scale: float) -> Tensor: # type: ignore[override] if not x.dtype.is_floating_point: raise TypeError( - f"`x` must be a float tensor, but is of type `{x.dtype}` instead." + f"`x` must be a float tensor, but is a `{x.dtype}` tensor instead." ) ctx.scale = scale - return x.detach().clone().requires_grad_(True) + return x.detach().clone() @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None]: # type: ignore[override] + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, None]: # type: ignore[override] return grad_output * ctx.scale, None def clip_gradient_norm( - module: Module, max_norm: Optional[float], norm_type: float = 2.0 + module: Module, max_norm: float | None, norm_type: float = 2.0 ) -> Tensor: """Clip the gradient norms ``module``. diff --git a/src/fairseq2/nn/utils/mask.py b/src/fairseq2/nn/utils/mask.py index 9232c8ad9..e9d05d9c2 100644 --- a/src/fairseq2/nn/utils/mask.py +++ b/src/fairseq2/nn/utils/mask.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Protocol import torch from torch import Tensor @@ -15,7 +15,7 @@ from fairseq2.typing import DataType, Device -def to_float_mask(mask: Tensor, dtype: Optional[DataType] = None) -> Tensor: +def to_float_mask(mask: Tensor, dtype: DataType | None = None) -> Tensor: """Convert a boolean mask to a float mask. :param mask: @@ -30,35 +30,50 @@ def to_float_mask(mask: Tensor, dtype: Optional[DataType] = None) -> Tensor: return torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -torch.inf) +class RowMaskFactory(Protocol): + def __call__( + self, + shape: tuple[int, int], + span_len: int, + max_mask_prob: float, + row_lens: Tensor | None = None, + min_num_spans: int = 0, + device: Device | None = None, + ) -> Tensor | None: + """Compute a random row mask of the specified shape. + + :param shape: + The shape of the mask. + :param span_len: + The length of each mask span. + :param max_mask_prob: + The maximum probability of masking an element in a row. + :param row_lens: + The length of each row. *Shape:* :math:`(R)`, where :math:`R` is the + number of rows. + :param min_num_spans: + The minimum number of mask spans per row. + :param device: + The device on which to initialize the mask. + + :returns: + The boolean row mask. *:Shape:* ``shape``. + """ + + def compute_row_mask( - shape: Tuple[int, int], + shape: tuple[int, int], span_len: int, max_mask_prob: float, - row_lens: Optional[Tensor] = None, + row_lens: Tensor | None = None, min_num_spans: int = 0, - device: Optional[Device] = None, -) -> Optional[Tensor]: - """Compute a random row mask of the specified shape. - - :param shape: - The shape of the mask. - :param span_len: - The length of each mask span. - :param max_mask_prob: - The maximum probability of masking an element in a row. Note that, due - to mask span overlap, the effective probability will be lower. The - implementation also guarantees that there will be always at least one - unmasked element in each row. - :param row_lens: - The length of each row. *Shape:* :math:`(R)`, where :math:`R` is the - number of rows. - :param min_num_spans: - The minimum number of mask spans per row. - :param device: - The device on which to initialize the mask. - - :returns: - The boolean row mask. *:Shape:* ``shape``. + device: Device | None = None, +) -> Tensor | None: + """Implements the :class:`RowMaskFactory` protocol. + + Note that, due to mask span overlap, the effective mask probability will be + lower than ``max_mask_prob``. The implementation also guarantees that there + will be always at least one unmasked element in each row. """ num_rows, max_row_len = shape @@ -93,7 +108,7 @@ def compute_row_mask( def _compute_mask_spans( row_lens: Tensor, span_len: int, max_mask_prob: float, min_num_spans: int -) -> Optional[Tensor]: +) -> Tensor | None: """Compute random mask spans of the specified shape.""" device, dtype = row_lens.device, row_lens.dtype @@ -165,7 +180,17 @@ def _generate_mask(indices: Tensor, max_row_len: int) -> Tensor: # (N, min(M x L)) # We randomly pick `min_num_masked` masked elements from each row, which # effectively unmasks the remaining elements. - indices = torch.multinomial(float_mask, num_samples=min_num_masked) + # + # We first make a tensor of random values and 0.001 to it to ensure the + # minimum value is larger than 0. Then we multiply it with the float_mask so + # that all the 0 values in `float_mask` are still 0 but the non-zero values + # have a random value assigned to them. Then we select the top-k values, + # which would be basically a subset of non-zero values `float_mask`. + random_values = torch.rand_like(float_mask) + 0.001 + + random_values = random_values * float_mask + + _, indices = torch.topk(random_values, k=min_num_masked, dim=1, sorted=False) # (N, S) # Now we construct the actual boolean mask which has the same number of diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index 4a5c13fea..fba6672df 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -7,23 +7,10 @@ from __future__ import annotations import re +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass from itertools import chain -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Mapping, - Optional, - Protocol, - Sequence, - Set, - Tuple, - runtime_checkable, -) +from typing import Protocol, runtime_checkable import torch from torch import Tensor @@ -31,11 +18,8 @@ from torch.nn.utils import remove_weight_norm # type: ignore[attr-defined] from fairseq2.gang import Gang -from fairseq2.logging import get_log_writer +from fairseq2.logging import log from fairseq2.typing import CPU, Device -from fairseq2.utils.rng import temporary_manual_seed - -log = get_log_writer(__name__) @runtime_checkable @@ -44,34 +28,20 @@ def reset_parameters(self) -> None: """Reset the parameters and buffers of the module.""" -def reset_parameters( - module: Module, *, recurse: bool = True, seed: Optional[int] = None -) -> None: +def reset_parameters(module: Module, *, recurse: bool = True) -> None: """Reset the parameters and buffers of ``module``. :param module: The module to reset. :param recurse: If ``True``, resets the parameters and buffers of descendant modules. - :param seed: - The random number generator seed to use during parameter initialization. """ def reset(name: str, m: Module) -> None: if isinstance(m, ModuleWithParameter): m.reset_parameters() - if seed is None: - devices: List[Device] = [] - else: - device = infer_device(module, recurse=recurse) - if device.type == "meta": - devices = [] - else: - devices = [CPU, device] - - with temporary_manual_seed(devices, seed): - visit_module(module, reset, recurse=recurse) + visit_module(module, reset, recurse=recurse) @runtime_checkable @@ -102,7 +72,7 @@ def visit_module( *, recurse: bool = True, post_order: bool = True, - memo: Optional[Set[Module]] = None, + memo: set[Module] | None = None, ) -> None: """Run ``visitor`` on ``module``. @@ -125,24 +95,27 @@ def visit_module( visitor(name, m) -def to_device(module: Module, device: Device, *, seed: Optional[int] = None) -> None: +def to_device(module: Module, device: Device) -> None: """Move the parameters and buffers of ``module`` to ``device``. :param module: The module to move. :param device: The target device of the parameters and buffers. - :param seed: - The random number generator seed to use during parameter initialization - if ``module`` is on the meta device. """ - modules: List[Tuple[Module, Device]] = [] + modules: list[tuple[Module, Device]] = [] for name, m in _get_named_modules(module, prefix="module", post_order=True): if m is None: continue - module_device = infer_device(m, name, recurse=False) + try: + module_device = infer_device(m, recurse=False) + except ValueError as ex: + raise ValueError( + f"The device of `{name}` is not valid. See the nested exception for details." + ) from ex + if module_device == device: continue @@ -151,21 +124,15 @@ def to_device(module: Module, device: Device, *, seed: Optional[int] = None) -> if not modules: return - memo: Dict[Tensor, Tensor] = {} - - if seed is None or device.type == "meta": - devices = [] - else: - devices = [CPU, device] + memo: dict[Tensor, Tensor] = {} - with temporary_manual_seed(devices, seed): - for m, module_device in modules: - if module_device.type != "meta": - apply_to_parameters(m, lambda t: t.to(device), recurse=False, memo=memo) - else: - to_empty(m, device, recurse=False, memo=memo) + for m, module_device in modules: + if module_device.type != "meta": + apply_to_parameters(m, lambda t: t.to(device), recurse=False, memo=memo) + else: + to_empty(m, device, recurse=False, memo=memo) - reset_parameters(m, recurse=False) + reset_parameters(m, recurse=False) def to_empty( @@ -173,7 +140,7 @@ def to_empty( device: Device, *, recurse: bool = True, - memo: Optional[Dict[Tensor, Tensor]] = None, + memo: dict[Tensor, Tensor] | None = None, ) -> None: """Move the parameters and buffers of ``module`` to ``device`` without copying storage. @@ -243,7 +210,7 @@ def collect_tensors(m: Module) -> None: # Do not memoize. No need anyways, and would also break the sync between the # traversed tensors and the iterator. - apply_to_parameters(target_module, lambda _: next(it), recurse=True, no_memo=True) + apply_to_parameters(target_module, lambda _: next(it), no_memo=True) def apply_to_parameters( @@ -251,7 +218,7 @@ def apply_to_parameters( fn: Callable[[Tensor], Tensor], *, recurse: bool = True, - memo: Optional[Dict[Tensor, Tensor]] = None, + memo: dict[Tensor, Tensor] | None = None, no_memo: bool = False, ) -> None: """Apply ``fn`` to the parameters and buffers of ``module``. @@ -321,7 +288,7 @@ def call_fn( setattr(module, buffer_name, call_fn(buffer)) -def freeze_parameters(module: Optional[Module], value: bool = True) -> None: +def freeze_parameters(module: Module | None, value: bool = True) -> None: """Set if ``module`` and its descendant modules should stop learning.""" if module is None: return @@ -331,7 +298,7 @@ def freeze_parameters(module: Optional[Module], value: bool = True) -> None: def select_parameters( module: Module, names: Sequence[str], *, exclude: bool = False -) -> Iterable[Tuple[str, Parameter]]: +) -> Iterable[tuple[str, Parameter]]: """Select the parameters of ``module`` and its descendant modules whose names match ``names``. @@ -372,9 +339,7 @@ def remove(name: str, m: Module) -> None: visit_module(module, remove, recurse=recurse) -def infer_device( - module: Module, name: str = "module", *, recurse: bool = True -) -> Device: +def infer_device(module: Module, *, recurse: bool = True) -> Device: """Infer the device on which ``module``'s parameters and buffers reside. :param module: @@ -399,10 +364,10 @@ def infer_device( if len(devices) == 1: return devices.pop() - s = ", ".join(sorted(f"'{d.type}'" for d in devices)) + s = ", ".join(sorted(f"`{d.type}`" for d in devices)) raise ValueError( - f"All parameters and buffers of `{name}` must be on the same device, but they are on {s}." + f"All parameters and buffers of `module` must be on the same device, but they are on {s}." ) @@ -427,7 +392,7 @@ def broadcast_module( warned = False - memo: Set[Tensor] = set() + memo: set[Tensor] = set() tensors = [] @@ -466,16 +431,17 @@ def broadcast_module( _broadcast_coalesced(pg, tensors, bucket_size, source_rank) -def load_state_dict(module: Module, state_dict: Mapping[str, Any]) -> None: +def load_state_dict( + module: Module, state_dict: Mapping[str, object], strict: bool = True +) -> None: """Copy parameters and buffers from ``state_dict`` into ``module`` and its descendant modules. - This implementation internally calls :meth:`Module.load_state_dict()` with - ``strict`` set to ``True``, and also enforces that ``state_dict`` does not - contain any keys corresponding to descendants that are set to ``None`` via - :meth:`Module.register_module()`. + This implementation internally calls :meth:`Module.load_state_dict()`, and also enforces that + ``state_dict`` does not contain any keys corresponding to descendants that are set to ``None`` + via :meth:`Module.register_module()`. """ - module.load_state_dict(state_dict, strict=True) + module.load_state_dict(state_dict, strict=strict) unexpected_keys = [] @@ -490,19 +456,23 @@ def load_state_dict(module: Module, state_dict: Mapping[str, Any]) -> None: unexpected_keys.append(key) if unexpected_keys: - raise RuntimeError( - f"Unexpected key(s) in `state_dict`: {', '.join(unexpected_keys)}" + unexpected_keys.sort() + + s = ", ".join(unexpected_keys) + + raise ValueError( + f"`state_dict` must not contain the following unexpected key(s): {s}" ) def _get_named_modules( - module: Optional[Module], + module: Module | None, *, prefix: str = "", recurse: bool = True, post_order: bool = False, - memo: Optional[Set[Module]] = None, -) -> Iterator[Tuple[str, Optional[Module]]]: + memo: set[Module] | None = None, +) -> Iterator[tuple[str, Module | None]]: if module is None: yield prefix, None @@ -542,7 +512,7 @@ def _get_named_modules( yield prefix, module -@dataclass +@dataclass(kw_only=True) class ModuleSizeInfo: """Holds the size information of a module.""" diff --git a/src/fairseq2/optim/__init__.py b/src/fairseq2/optim/__init__.py index d29f85a8e..b10b5b334 100644 --- a/src/fairseq2/optim/__init__.py +++ b/src/fairseq2/optim/__init__.py @@ -6,6 +6,14 @@ from __future__ import annotations +from fairseq2.optim.adamw import ADAMW_OPTIMIZER as ADAMW_OPTIMIZER from fairseq2.optim.adamw import AdamW as AdamW +from fairseq2.optim.adamw import AdamWConfig as AdamWConfig +from fairseq2.optim.adamw import AdamWHandler as AdamWHandler from fairseq2.optim.dynamic_loss_scaler import DynamicLossScaler as DynamicLossScaler from fairseq2.optim.dynamic_loss_scaler import LossScaleResult as LossScaleResult +from fairseq2.optim.handler import OptimizerHandler as OptimizerHandler +from fairseq2.optim.handler import OptimizerNotFoundError as OptimizerNotFoundError +from fairseq2.optim.optimizer import AbstractOptimizer as AbstractOptimizer +from fairseq2.optim.optimizer import ParameterCollection as ParameterCollection +from fairseq2.optim.static import create_optimizer as create_optimizer diff --git a/src/fairseq2/optim/adamw.py b/src/fairseq2/optim/adamw.py index 1f6cf8ee2..23e97ffda 100644 --- a/src/fairseq2/optim/adamw.py +++ b/src/fairseq2/optim/adamw.py @@ -6,15 +6,20 @@ from __future__ import annotations +from dataclasses import dataclass from itertools import chain -from typing import Any, Dict, Iterable, List, Literal, Tuple, Union, cast, final +from typing import Any, Final, Literal, cast, final import torch from torch import Tensor +from torch.optim import Optimizer from torch.optim.adamw import adamw # type: ignore[attr-defined] +from typing_extensions import override -from fairseq2.optim.optimizer import AbstractOptimizer -from fairseq2.typing import override +from fairseq2.error import NotSupportedError +from fairseq2.optim.handler import OptimizerHandler +from fairseq2.optim.optimizer import AbstractOptimizer, ParameterCollection +from fairseq2.typing import safe_cast @final @@ -28,10 +33,10 @@ class AdamW(AbstractOptimizer): def __init__( self, - params: Union[Iterable[Tensor], Iterable[Dict[str, Any]]], + params: ParameterCollection, *, lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, amsgrad: bool = False, @@ -56,8 +61,7 @@ def __init__( :param amsgrad: If ``True``, uses the AMSGrad variant. :param maximize: - If ``True``, maximizes the parameters based on the objective, - instead of minimizing. + If ``True``, maximizes the parameters instead of minimizing. :param capturable: If ``True``, it is safe to capture this instance in a CUDA graph. :param differentiable: @@ -66,10 +70,9 @@ def __init__( The implementation variant. See :class:`torch.optim.AdamW` for details. :param use_fp32: - If ``True``, stores the optimizer state (e.g. momentum) in single - precision (i.e. ``torch.float32``) and, during a ``step()`` call, - converts gradients on-the-fly to single precision for better - numerical stability for low-precision training. + If ``True``, stores the optimizer state in single precision and + converts gradients on-the-fly to single precision for numerical + stability. """ defaults = { "lr": lr, @@ -88,20 +91,20 @@ def __init__( if impl == "fused": if differentiable: - raise RuntimeError( + raise NotSupportedError( "`fused` implementation does not support `differentiable`." ) for pg in self.param_groups: for p in pg["params"]: if not torch.is_floating_point(p) or p.device.type != "cuda": - raise RuntimeError( + raise NotSupportedError( "`fused` implementation requires all parameters to be float CUDA tensors." ) self._step_supports_amp_scaling = True - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: super().load_state_dict(state_dict) state_keys = ["step", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"] @@ -148,12 +151,12 @@ def _do_step(self) -> None: for pg in self.param_groups: use_fp32: bool = pg["use_fp32"] - params_with_grad: List[Tensor] = [] - grads: List[Tensor] = [] - steps: List[Tensor] = [] - exp_avgs: List[Tensor] = [] - exp_avg_sqs: List[Tensor] = [] - max_exp_avg_sqs: List[Tensor] = [] + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + steps: list[Tensor] = [] + exp_avgs: list[Tensor] = [] + exp_avg_sqs: list[Tensor] = [] + max_exp_avg_sqs: list[Tensor] = [] amsgrad = pg["amsgrad"] beta1, beta2 = pg["betas"] @@ -171,7 +174,7 @@ def _do_step(self) -> None: amsgrad, ) - kwargs: Dict[str, Any] = {} + kwargs: dict[str, object] = {} if pg["differentiable"]: kwargs["differentiable"] = True @@ -222,14 +225,14 @@ def _do_step(self) -> None: def _init_param( self, param: Tensor, - param_group: Dict[str, Any], + param_group: dict[str, object], use_fp32: bool, - params_with_grad: List[Tensor], - grads: List[Tensor], - steps: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], + params_with_grad: list[Tensor], + grads: list[Tensor], + steps: list[Tensor], + exp_avgs: list[Tensor], + exp_avg_sqs: list[Tensor], + max_exp_avg_sqs: list[Tensor], amsgrad: bool, ) -> None: grad = param.grad @@ -237,9 +240,9 @@ def _init_param( return if grad.is_sparse: - raise RuntimeError("`AdamW` does not support sparse gradients.") + raise NotSupportedError("`AdamW` does not support sparse gradients.") - state = cast(Dict[str, Tensor], self.state[param]) # type: ignore[index] + state = cast(dict[str, Tensor], self.state[param]) # type: ignore[index] if use_fp32: if param.dtype != torch.float32: @@ -277,3 +280,67 @@ def _init_param( if amsgrad: max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + +ADAMW_OPTIMIZER: Final = "adamw" + + +@dataclass(kw_only=True) +class AdamWConfig: + lr: float = 1e-3 + """The learning rate.""" + + betas: tuple[float, float] = (0.9, 0.999) + """The coefficients used for computing running averages of gradient and its + square.""" + + eps: float = 1e-8 + """The term added to the denominator to improve numerical stability.""" + + weight_decay: float = 0.0 + """The weight decay coefficient.""" + + amsgrad: bool = False + """If ``True``, uses the AMSGrad variant.""" + + maximize: bool = False + """If ``True``, maximizes the parameters instead of minimizing.""" + + capturable: bool = False + """If ``True``, it is safe to capture this instance in a CUDA graph.""" + + differentiable: bool = False + """If ``True``, runs the optimizer step under autograd.""" + + impl: Literal["auto", "foreach", "fused", "naive"] = "auto" + """The implementation variant. See :class:`torch.optim.AdamW` for details.""" + + use_fp32: bool = False + """If ``True``, stores the optimizer state in single precision and converts + gradients on-the-fly to single precision for numerical stability.""" + + +@final +class AdamWHandler(OptimizerHandler): + @override + def create(self, params: ParameterCollection, config: object) -> Optimizer: + config = safe_cast("config", config, AdamWConfig) + + return AdamW( + params, + lr=config.lr, + betas=config.betas, + eps=config.eps, + weight_decay=config.weight_decay, + amsgrad=config.amsgrad, + maximize=config.maximize, + capturable=config.capturable, + differentiable=config.differentiable, + impl=config.impl, + use_fp32=config.use_fp32, + ) + + @property + @override + def config_kls(self) -> type: + return AdamWConfig diff --git a/src/fairseq2/optim/dynamic_loss_scaler.py b/src/fairseq2/optim/dynamic_loss_scaler.py index 81c8eda14..feb7df2f3 100644 --- a/src/fairseq2/optim/dynamic_loss_scaler.py +++ b/src/fairseq2/optim/dynamic_loss_scaler.py @@ -7,8 +7,11 @@ from __future__ import annotations import math +import warnings +from collections.abc import Callable, Mapping from dataclasses import dataclass -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union, cast, final +from typing import cast, final +from warnings import catch_warnings import torch from torch import Tensor @@ -16,11 +19,10 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.optim import Optimizer +from fairseq2.error import InvalidOperationError from fairseq2.gang import Gang -from fairseq2.logging import get_log_writer -from fairseq2.typing import Device, override - -log = get_log_writer(__name__) +from fairseq2.logging import log +from fairseq2.typing import Device @final @@ -31,10 +33,10 @@ class DynamicLossScaler: _optimizer: Optimizer _scale_window: int _min_scale: float - _is_enabled: bool + _enabled: bool - # TODO: consolidate into `GradScaler` once we cease support for PT2.2 - _grad_scaler: Union[GradScaler, ShardedGradScaler] + # compat: consolidate into `GradScaler` once we cease support for PT2.2 + _grad_scaler: GradScaler | ShardedGradScaler def __init__( self, @@ -44,7 +46,7 @@ def __init__( sharded: bool = True, init_scale: float = 2.0**15, scale_factor: float = 2.0, - scale_window: Optional[int] = None, + scale_window: int | None = None, min_scale: float = 0.0, gradient_accumulation: int = 1, enabled: bool = True, @@ -97,13 +99,16 @@ def __init__( scale_window = 1 if not enabled or not sharded or gang.size == 1: - self._grad_scaler = _InternalGradScaler( - init_scale=init_scale, - growth_factor=scale_factor, - backoff_factor=1 / scale_factor, - growth_interval=scale_window, - enabled=enabled, - ) + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress deprecation warning. + + self._grad_scaler = _InternalGradScaler( + init_scale=init_scale, + growth_factor=scale_factor, + backoff_factor=1 / scale_factor, + growth_interval=scale_window, + enabled=enabled, + ) else: if not supports_manual_gradient_scaling(optimizer): raise ValueError( @@ -124,22 +129,32 @@ def __init__( self._optimizer = optimizer self._scale_window = scale_window self._min_scale = min_scale - self._is_enabled = enabled + self._enabled = enabled - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, object]: return {"grad_scaler": self._grad_scaler.state_dict()} - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: + try: + gs_state_dict = state_dict["grad_scaler"] + except KeyError: + raise ValueError("`state_dict` must contain a 'grad_scaler' key.") from None + + if not isinstance(gs_state_dict, dict): + raise TypeError( + f"`state_dict['grad_scaler']` must be of type `dict`, but is of type `{type(gs_state_dict)}` instead." + ) + try: - self._grad_scaler.load_state_dict(state_dict["grad_scaler"]) - except KeyError as ex: + self._grad_scaler.load_state_dict(gs_state_dict) + except (RuntimeError, ValueError) as ex: raise ValueError( - "`state_dict` must contain the state of the internal `GradScaler`." + f"`state_dict['grad_scaler']` is not a valid `{type(self._grad_scaler)}` state. See the nested exception for details." ) from ex def run_optimizer_step( - self, step_nr: int, closure: Optional[Callable[[], float]] = None - ) -> Tuple[Optional[float], LossScaleResult]: + self, step_nr: int, closure: Callable[[], float] | None = None + ) -> tuple[float | None, LossScaleResult]: """Perform a single optimization step. :param step_nr: @@ -193,8 +208,8 @@ def _are_close(a: float, b: float) -> bool: def unscale_gradients_(self) -> None: """Unscale the associated optimizer's gradients by the current scale.""" - if not supports_manual_gradient_scaling(self._optimizer): - raise RuntimeError( + if self._enabled and not supports_manual_gradient_scaling(self._optimizer): + raise InvalidOperationError( "`optimizer` must support manual gradient scaling via `torch.cuda.amp.GradScaler`, but supports only implicit scaling in its step function (i.e. `_step_supports_amp_scaling == True`)." ) @@ -211,10 +226,9 @@ def get_scale(self) -> float: @property def is_enabled(self) -> bool: """``True`` if the loss scaling is enabled.""" - return self._is_enabled + return self._enabled -@final @dataclass(frozen=True) class LossScaleResult: """Holds the result of a loss scale operation.""" @@ -233,17 +247,19 @@ class LossScaleResult: def supports_manual_gradient_scaling(optimizer: Optimizer) -> bool: - """Return ``True`` if ``optimizer`` supports manual gradient scaling via - ``torch.cuda.amp.GradScaler``.""" + """ + Returns ``True`` if ``optimizer`` supports manual gradient scaling via + ``torch.cuda.amp.GradScaler``. + """ return not getattr(optimizer, "_step_supports_amp_scaling", False) # An ugly hack. class _InternalGradScaler(GradScaler): - @override + # override def _unscale_grads_( self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, _: bool - ) -> Dict[Device, Tensor]: + ) -> dict[Device, Tensor]: # `GradScaler` artificially limits fp16 gradients only to optimizers # that natively support AMP. Here, we hijack `_unscale_grads_()` and # always pass `allow_fp16=True` to the real function. diff --git a/src/fairseq2/optim/handler.py b/src/fairseq2/optim/handler.py new file mode 100644 index 000000000..1bbaa1a06 --- /dev/null +++ b/src/fairseq2/optim/handler.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from torch.optim import Optimizer + +from fairseq2.optim.optimizer import ParameterCollection + + +class OptimizerHandler(ABC): + @abstractmethod + def create(self, params: ParameterCollection, config: object) -> Optimizer: + ... + + @property + @abstractmethod + def config_kls(self) -> type: + ... + + +class OptimizerNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known optimizer.") + + self.name = name diff --git a/src/fairseq2/optim/lr_scheduler.py b/src/fairseq2/optim/lr_scheduler.py deleted file mode 100644 index a5e909a10..000000000 --- a/src/fairseq2/optim/lr_scheduler.py +++ /dev/null @@ -1,542 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import math -import warnings -from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Tuple, Union, final - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from typing_extensions import TypeAlias - -from fairseq2.typing import override - -LRScheduler: TypeAlias = _LRScheduler - - -def get_effective_lr(scheduler: LRScheduler) -> float: - """Return the effective learning rate computed by ``scheduler``.""" - return scheduler.get_last_lr()[0] - - -class AbstractLRScheduler(ABC, LRScheduler): - """Provides a skeletal implementation of :class:`LRScheduler`.""" - - @final - @override - def get_lr(self) -> List[float]: # type: ignore[override] - if not self._get_lr_called_within_step: # type: ignore[attr-defined] - warnings.warn( - "To get the last learning rate computed by the scheduler, use `get_last_lr()`." - ) - - return self._compute_lrs() - - @abstractmethod - def _compute_lrs(self) -> List[float]: - """Compute the learning rate of each parameter group.""" - - -@final -class NoopLR(AbstractLRScheduler): - """Represents a no-op learning rate schedule.""" - - def __init__(self, optimizer: Optimizer, *, last_epoch: int = -1) -> None: - super().__init__(optimizer, last_epoch) - - @override - def _compute_lrs(self) -> List[float]: - return self.base_lrs - - -@final -class CosineAnnealingLR(AbstractLRScheduler): - """Represents the learning rate schedule described in - :cite:t:`https://doi.org/10.48550/arxiv.1608.03983`. - - **During warmup:** - - .. math:: - \\eta_t = \\eta_{base} \\frac{t}{T_{warmup}} - - **After warmup:** - - .. math:: - \\eta_t = \\eta_{final}^i + \\frac{1}{2} (\\eta_{base}^i - \\eta_{final}^i) (1 + \\text{cos}(\\pi \\frac{t_{i}}{T_{i}})) - - where :math:`i` is the number of the current annealing cycle, :math:`t_i` is - the number of steps taken since the last restart, and :math:`T_i` is the - total number of steps within the :math:`i`-th cycle (i.e. *length* of the - cycle). - - *Cosine Annealing* is a type of learning rate schedule that has the effect - of starting with a large learning rate that is relatively rapidly decreased - to a minimum value before being increased rapidly again. - - Please refer to the paper to learn more about the details. - - In addition to the original schedule, this implementation also supports a - warmup phase where the learning rate is linearly increased for the first - :math:`T_{warmup}` training steps to the base learning rate. - - .. note:: - This scheduler is not chainable. - """ - - _cycle_len: int - _cycle_mul: float - _num_warmup_steps: int - _lr_mul: float - _start_lrs: Sequence[float] - _final_lrs: Sequence[float] - - def __init__( - self, - optimizer: Optimizer, - cycle_len: int, - num_warmup_steps: int, - *, - cycle_mul: float = 1.0, - lr_mul: float = 1.0, - start_lr: Union[float, Sequence[float]] = 0.0, - final_lr: Union[float, Sequence[float]] = 0.0, - last_epoch: int = -1, - ) -> None: - """ - :param optimizer: - The associated optimizer. - :param cycle_len: - The number of steps within the first cycle. - :param num_warmup_steps: - The number of warmup steps. - :param cycle_mul: - The factor to grow the length of each cycle. - :param lr_mul: - The factor to scale the base and final learning rate at the end of - each cycle. - :param start_lr: - The initial warmup learning rate of all parameter groups, or of each - parameter group respectively. - :param final_lr: - The final learning rate of all parameter groups, or of each - parameter group respectively, at the end of the first cycle. - :param last_epoch: - The index of the last epoch. - """ - self._cycle_len = cycle_len - self._cycle_mul = cycle_mul - self._num_warmup_steps = num_warmup_steps - self._lr_mul = lr_mul - - self._start_lrs = _get_per_param_group(optimizer, "start_lr", start_lr) - self._final_lrs = _get_per_param_group(optimizer, "final_lr", final_lr) - - super().__init__(optimizer, last_epoch) - - @override - def _compute_lrs(self) -> List[float]: - base_lrs = self.base_lrs - - # Linearly increase the learning rate to its base value during warmup. - if self.last_epoch < self._num_warmup_steps: - c = self.last_epoch / self._num_warmup_steps - - return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] - - curr_step = self.last_epoch - self._num_warmup_steps - - # When each cycle has equal length, the computation is straightforward. - if self._cycle_mul == 1.0: - cycle_nr = curr_step // self._cycle_len - - cycle_len = self._cycle_len - - # The position of the step within the cycle. - cycle_pos = curr_step - (cycle_nr * cycle_len) - - # Otherwise, it becomes a bit trickier. We have to treat the cycles as - # a geometric series to find out the number, length, and offset of the - # current cycle. - else: - mul = self._cycle_mul - - # Solve the equation \sum_{i=0}^{n} len(cycle_i) + x = step for n. - cycle_nr = int(math.log(1 - curr_step / self._cycle_len * (1 - mul), mul)) - - cycle_len = int(mul**cycle_nr * self._cycle_len) - - # Compute the sum of the lengths of the first `cycle_nr` cycles - # (i.e. geometric series) which corresponds to the beginning offset - # of the current cycle. - cycle_offset = int((1 - mul**cycle_nr) / (1 - mul) * self._cycle_len) - - # The position of the step within the cycle. - cycle_pos = curr_step - cycle_offset - - lr_mul = self._lr_mul**cycle_nr - - c = math.cos(math.pi * cycle_pos / cycle_len) - - min_lrs, max_lrs = self._final_lrs, base_lrs - - return [self._cycle_lr(mn, mx, lr_mul, c) for mn, mx in zip(min_lrs, max_lrs)] - - def _cycle_lr(self, min_lr: float, max_lr: float, lr_mul: float, c: float) -> float: - min_lr *= lr_mul - max_lr *= lr_mul - - return min_lr + 0.5 * (max_lr - min_lr) * (1 + c) - - -@final -class MyleLR(AbstractLRScheduler): - """Represents a scaled version of :class:`NoamLR` that preserves the base - learning rate of the associated optimizer. - - .. math:: - \\eta_t = \\eta_{base} \\min(\\sqrt{\\frac{T_{warmup}}{t}}, \\frac{t}{T_{warmup}}) - - Essentially, this is Noam learning rate schedule scaled by the square root - of the number of warmup steps. It was originally proposed and implemented by - Myle Ott in fairseq under the name ``InverseSquareRootLR``. - - It corresponds to increasing the learning rate linearly for the first - :math:`T_{warmup}` training steps to the base learning rate, and decreasing - it thereafter proportionally to the inverse square root of the step number. - - .. note:: - This scheduler is not chainable. - """ - - _num_warmup_steps: int - _start_lrs: Sequence[float] - - def __init__( - self, - optimizer: Optimizer, - num_warmup_steps: int, - *, - start_lr: Union[float, Sequence[float]] = 0.0, - last_epoch: int = -1, - ) -> None: - """ - :param optimizer: - The associated optimizer. - :param num_warmup_steps: - The number of warmup steps. - :param start_lr: - The initial warmup learning rate of all parameter groups, or of each - parameter group respectively. - :param last_epoch: - The index of the last epoch. - """ - if num_warmup_steps == 0: - raise ValueError("`num_warmup_steps` must be greater than 0.") - - self._num_warmup_steps = num_warmup_steps - - self._start_lrs = _get_per_param_group(optimizer, "start_lr", start_lr) - - super().__init__(optimizer, last_epoch) - - @override - def _compute_lrs(self) -> List[float]: - base_lrs = self.base_lrs - - # Linearly increase the learning rate to its base value during warmup. - if self.last_epoch < self._num_warmup_steps: - c = self.last_epoch / self._num_warmup_steps - - return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] - - # After the warmup, decay the learning rate proportional to the inverse - # square root of the step number. - c = (self._num_warmup_steps / self.last_epoch) ** 0.5 - - return [b * c for b in base_lrs] - - -@final -class NoamLR(AbstractLRScheduler): - """Represents the learning rate schedule described in Section 5.3 of - :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. - - .. math:: - \\eta_t = \\eta_{base} \\min(\\frac{1}{\\sqrt{t}}, \\frac{t}{T_{warmup}} \\frac{1}{\\sqrt{T_{warmup}}}) - - This corresponds to increasing the learning rate linearly for the first - :math:`T_{warmup}` training steps, and decreasing it thereafter - proportionally to the inverse square root of the step number. In the paper, - the authors use the square root of the dimensionality of the model as - :math:`\\eta_{base}`. - - This scheduler is commonly referred to as Noam, after the second author of - the paper, Noam Shazeer. - - .. note:: - This scheduler is not chainable. - """ - - _num_warmup_steps: int - - def __init__( - self, - optimizer: Optimizer, - num_warmup_steps: int, - *, - last_epoch: int = -1, - ) -> None: - """ - :param optimizer: - The associated optimizer. - :param num_warmup_steps: - The number of warmup steps. - :param last_epoch: - The index of the last epoch. - """ - self._num_warmup_steps = num_warmup_steps - - super().__init__(optimizer, last_epoch) - - @override - def _compute_lrs(self) -> List[float]: - # Linearly increase the learning rate during warmup. - if self.last_epoch < self._num_warmup_steps: - c = self.last_epoch * self._num_warmup_steps**-1.5 - - # No warmup requested, decay from the base learning rate. - elif self.last_epoch == 0: - c = 1.0 - - # After the warmup, decay the learning rate proportional to the inverse - # square root of the step number. - else: - c = self.last_epoch**-0.5 - - return [b * c for b in self.base_lrs] - - -@final -class PolynomialDecayLR(AbstractLRScheduler): - """Represents the polynomial decay learning rate schedule. - - **During warmup:** - - .. math:: - \\eta_t = \\eta_{base} \\frac{t}{T_{warmup}} - - **After warmup:** - - .. math:: - \\eta_t = \\eta_{final} + (\\eta_{base} - \\eta_{final}) (\\frac{T - t}{T - T_{warmup}})^{p} - - This corresponds to increasing the learning rate linearly for the first - :math:`T_{warmup}` training steps to the base learning rate, and decreasing - it thereafter for :math:`T - T_{warmup}` steps to the final learning rate - using a polynomial of degree :math:`p`. - - .. note:: - This scheduler is not chainable. - """ - - _num_steps: int - _num_warmup_steps: int - _power: float - _start_lrs: Sequence[float] - _final_lrs: Sequence[float] - - def __init__( - self, - optimizer: Optimizer, - num_steps: int, - num_warmup_steps: int, - *, - power: float = 1.0, - start_lr: Union[float, Sequence[float]] = 0.0, - final_lr: Union[float, Sequence[float]] = 0.0, - last_epoch: int = -1, - ) -> None: - """ - :param optimizer: - The associated optimizer. - :param num_steps: - The total number of steps, including warmup, over which to decay the - learning rate. - :param num_warmup_steps: - The number of warmup steps. - :param power: - The exponent of the polynomial used for decay. - :param start_lr: - The initial warmup learning rate of all parameter groups, or of each - parameter group respectively. - :param final_lr: - The final learning rate of all parameter groups, or of each - parameter group respectively. - :param last_epoch: - The index of the last epoch. - """ - if num_warmup_steps >= num_steps: - raise ValueError( - f"`num_warmup_steps` must be less than `num_steps` ({num_steps}), but is {num_warmup_steps} instead." - ) - - self._num_steps = num_steps - self._num_warmup_steps = num_warmup_steps - self._power = power - - self._start_lrs = _get_per_param_group(optimizer, "start_lr", start_lr) - self._final_lrs = _get_per_param_group(optimizer, "final_lr", final_lr) - - super().__init__(optimizer, last_epoch) - - @override - def _compute_lrs(self) -> List[float]: - base_lrs = self.base_lrs - - # The decay is already complete, return the final learning rate. - if self.last_epoch >= self._num_steps: - return [f for f in self._final_lrs] - - # Linearly increase the learning rate to its base value during warmup. - if self.last_epoch < self._num_warmup_steps: - c = self.last_epoch / self._num_warmup_steps - - return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] - - # After the warmup, decay the learning rate to its final value. - r = self._num_steps - self.last_epoch - t = self._num_steps - self._num_warmup_steps - - c = (r / t) ** self._power - - return [f + (b - f) * c for b, f in zip(base_lrs, self._final_lrs)] - - -@final -class TriStageLR(AbstractLRScheduler): - """Represents the tri-stage learning rate schedule as described in Section - 3.2 of :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. - - The learning rate schedule employs three stages: - - - The warm-up stage where the learning rate is linearly increased to its - maximum value (i.e. `base_lr`) - - The hold stage where the learning rate is kept constant at its maximum - value. - - The decay stage where the learning rate is exponentially decayed to its - final value. - - .. note:: - This scheduler is not chainable. - """ - - _num_steps: int - _start_lr_scales: Sequence[float] - _final_lr_scales: Sequence[float] - _start_lrs: Optional[Sequence[float]] - _final_lrs: Optional[Sequence[float]] - _num_stage1_steps: int - _num_stage2_steps: int - _num_stage3_steps: int - - def __init__( - self, - optimizer: Optimizer, - num_steps: int, - stage_ratio: Tuple[float, float, float], - *, - start_lr_scale: Union[float, Sequence[float]] = 0.01, - final_lr_scale: Union[float, Sequence[float]] = 0.01, - last_epoch: int = -1, - ) -> None: - """ - :param optimizer: - The associated optimizer. - :param num_steps: - The total number of steps over which to adjust the learning rate. - :param stage_ratio: - The ratios of warmup, hold, and decay stages. Must add up to 1. - :param start_lr_scale: - The scale of the initial warm-up learning rate. - :param final_lr_scale: - The scale of the final learning rate. - """ - if not math.isclose((s := sum(stage_ratio)), 1.0): - raise ValueError( - f"The sum of `stage_ratio` values must be 1.0, but is {s} instead." - ) - - self._num_steps = num_steps - - self._start_lr_scales = _get_per_param_group( - optimizer, "start_lr", start_lr_scale - ) - self._final_lr_scales = _get_per_param_group( - optimizer, "final_lr", final_lr_scale - ) - - self._start_lrs = None - self._final_lrs = None - - self._num_stage1_steps = int(stage_ratio[0] * num_steps) - self._num_stage2_steps = int(stage_ratio[1] * num_steps) - self._num_stage3_steps = int(stage_ratio[2] * num_steps) - - super().__init__(optimizer, last_epoch) - - @override - def _compute_lrs(self) -> List[float]: - base_lrs = self.base_lrs - - # Due to `LRScheduler`'s constructor quirks, we delay the initialization - # of `start_lrs` and `final_lrs` to here. - if self._start_lrs is None: - self._start_lrs = [s * b for s, b in zip(self._start_lr_scales, base_lrs)] - - if self._final_lrs is None: - self._final_lrs = [s * b for s, b in zip(self._final_lr_scales, base_lrs)] - - num_steps = self.last_epoch - - # Linearly increase the learning rate to its base value during warmup. - if num_steps < self._num_stage1_steps: - c = num_steps / self._num_stage1_steps - - return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] - - num_steps -= self._num_stage1_steps - - # Keep the learning rate constant during second stage. - if num_steps < self._num_stage2_steps: - return list(base_lrs) - - num_steps -= self._num_stage2_steps - - if num_steps < self._num_stage3_steps: - c = num_steps / self._num_stage3_steps - - return [b * math.exp(math.log(f) * c) for b, f in zip(base_lrs, self._final_lr_scales)] # fmt: skip - - return list(self._final_lrs) - - -def _get_per_param_group( - optimizer: Optimizer, name: str, value: Union[float, Sequence[float]] -) -> Sequence[float]: - num_param_groups = len(optimizer.param_groups) - - if isinstance(value, float): - return [value] * num_param_groups - - if len(value) != num_param_groups: - raise ValueError( - f"The length of `{name}` must be equal to the number of parameter groups ({num_param_groups}), but is {len(value)} instead." - ) - - return value diff --git a/src/fairseq2/optim/lr_scheduler/__init__.py b/src/fairseq2/optim/lr_scheduler/__init__.py new file mode 100644 index 000000000..2858caa8e --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.optim.lr_scheduler.cosine_annealing import ( + COSINE_ANNEALING_LR as COSINE_ANNEALING_LR, +) +from fairseq2.optim.lr_scheduler.cosine_annealing import ( + CosineAnnealingLR as CosineAnnealingLR, +) +from fairseq2.optim.lr_scheduler.cosine_annealing import ( + CosineAnnealingLRConfig as CosineAnnealingLRConfig, +) +from fairseq2.optim.lr_scheduler.cosine_annealing import ( + CosineAnnealingLRHandler as CosineAnnealingLRHandler, +) +from fairseq2.optim.lr_scheduler.handler import LRSchedulerHandler as LRSchedulerHandler +from fairseq2.optim.lr_scheduler.handler import ( + LRSchedulerNotFoundError as LRSchedulerNotFoundError, +) +from fairseq2.optim.lr_scheduler.lr_scheduler import ( + AbstractLRScheduler as AbstractLRScheduler, +) +from fairseq2.optim.lr_scheduler.lr_scheduler import LRScheduler as LRScheduler +from fairseq2.optim.lr_scheduler.lr_scheduler import NoopLR as NoopLR +from fairseq2.optim.lr_scheduler.lr_scheduler import ( + get_effective_lr as get_effective_lr, +) +from fairseq2.optim.lr_scheduler.myle import MYLE_LR as MYLE_LR +from fairseq2.optim.lr_scheduler.myle import MyleLR as MyleLR +from fairseq2.optim.lr_scheduler.myle import MyleLRConfig as MyleLRConfig +from fairseq2.optim.lr_scheduler.myle import MyleLRHandler as MyleLRHandler +from fairseq2.optim.lr_scheduler.noam import NOAM_LR as NOAM_LR +from fairseq2.optim.lr_scheduler.noam import NoamLR as NoamLR +from fairseq2.optim.lr_scheduler.noam import NoamLRConfig as NoamLRConfig +from fairseq2.optim.lr_scheduler.noam import NoamLRHandler as NoamLRHandler +from fairseq2.optim.lr_scheduler.polynomial_decay import ( + POLYNOMIAL_DECAY_LR as POLYNOMIAL_DECAY_LR, +) +from fairseq2.optim.lr_scheduler.polynomial_decay import ( + PolynomialDecayLR as PolynomialDecayLR, +) +from fairseq2.optim.lr_scheduler.polynomial_decay import ( + PolynomialDecayLRConfig as PolynomialDecayLRConfig, +) +from fairseq2.optim.lr_scheduler.polynomial_decay import ( + PolynomialDecayLRHandler as PolynomialDecayLRHandler, +) +from fairseq2.optim.lr_scheduler.static import ( + create_lr_scheduler as create_lr_scheduler, +) +from fairseq2.optim.lr_scheduler.tri_stage import TRI_STAGE_LR as TRI_STAGE_LR +from fairseq2.optim.lr_scheduler.tri_stage import TriStageLR as TriStageLR +from fairseq2.optim.lr_scheduler.tri_stage import TriStageLRConfig as TriStageLRConfig +from fairseq2.optim.lr_scheduler.tri_stage import TriStageLRHandler as TriStageLRHandler diff --git a/src/fairseq2/optim/lr_scheduler/cosine_annealing.py b/src/fairseq2/optim/lr_scheduler/cosine_annealing.py new file mode 100644 index 000000000..6fe1c2871 --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/cosine_annealing.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Final, final + +from torch.optim import Optimizer +from typing_extensions import override + +from fairseq2.logging import log +from fairseq2.optim.lr_scheduler.handler import LRSchedulerHandler +from fairseq2.optim.lr_scheduler.lr_scheduler import ( + AbstractLRScheduler, + LRScheduler, + get_per_param_group, +) +from fairseq2.typing import safe_cast + + +@final +class CosineAnnealingLR(AbstractLRScheduler): + """Represents the learning rate schedule described in + :cite:t:`https://doi.org/10.48550/arxiv.1608.03983`. + + **During warmup:** + + .. math:: + \\eta_t = \\eta_{base} \\frac{t}{T_{warmup}} + + **After warmup:** + + .. math:: + \\eta_t = \\eta_{final}^i + \\frac{1}{2} (\\eta_{base}^i - \\eta_{final}^i) (1 + \\text{cos}(\\pi \\frac{t_{i}}{T_{i}})) + + where :math:`i` is the number of the current annealing cycle, :math:`t_i` is + the number of steps taken since the last restart, and :math:`T_i` is the + total number of steps within the :math:`i`-th cycle (i.e. *length* of the + cycle). + + *Cosine Annealing* is a type of learning rate schedule that has the effect + of starting with a large learning rate that is relatively rapidly decreased + to a minimum value before being increased rapidly again. + + Please refer to the paper to learn more about the details. + + In addition to the original schedule, this implementation also supports a + warmup phase where the learning rate is linearly increased for the first + :math:`T_{warmup}` training steps to the base learning rate. + + .. note:: + This scheduler is not chainable. + """ + + _cycle_len: int + _cycle_mul: float + _num_warmup_steps: int + _lr_mul: float + _start_lrs: Sequence[float] + _final_lrs: Sequence[float] + + def __init__( + self, + optimizer: Optimizer, + cycle_len: int, + num_warmup_steps: int, + *, + cycle_mul: float = 1.0, + lr_mul: float = 1.0, + start_lr: float | Sequence[float] = 0.0, + final_lr: float | Sequence[float] = 0.0, + last_epoch: int = -1, + ) -> None: + """ + :param optimizer: + The optimizer to associate. + :param cycle_len: + The number of steps within the first cycle. + :param num_warmup_steps: + The number of warmup steps. + :param cycle_mul: + The factor to grow the length of each cycle. + :param lr_mul: + The factor to scale the base and final learning rate at the end of + each cycle. + :param start_lr: + The initial warmup learning rate of all parameter groups, or of each + parameter group respectively. + :param final_lr: + The final learning rate of all parameter groups, or of each + parameter group respectively, at the end of the first cycle. + :param last_epoch: + The index of the last epoch. + """ + self._cycle_len = cycle_len + self._cycle_mul = cycle_mul + self._num_warmup_steps = num_warmup_steps + self._lr_mul = lr_mul + + self._start_lrs = get_per_param_group(optimizer, "start_lr", start_lr) + self._final_lrs = get_per_param_group(optimizer, "final_lr", final_lr) + + super().__init__(optimizer, last_epoch) + + @override + def _compute_lrs(self) -> list[float]: + base_lrs = self.base_lrs + + # Linearly increase the learning rate to its base value during warmup. + if self.last_epoch < self._num_warmup_steps: + c = self.last_epoch / self._num_warmup_steps + + return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] + + curr_step = self.last_epoch - self._num_warmup_steps + + # When each cycle has equal length, the computation is straightforward. + if self._cycle_mul == 1.0: + cycle_nr = curr_step // self._cycle_len + + cycle_len = self._cycle_len + + # The position of the step within the cycle. + cycle_pos = curr_step - (cycle_nr * cycle_len) + + # Otherwise, it becomes a bit trickier. We have to treat the cycles as + # a geometric series to find out the number, length, and offset of the + # current cycle. + else: + mul = self._cycle_mul + + # Solve the equation \sum_{i=0}^{n} len(cycle_i) + x = step for n. + cycle_nr = int(math.log(1 - curr_step / self._cycle_len * (1 - mul), mul)) + + cycle_len = int(mul**cycle_nr * self._cycle_len) + + # Compute the sum of the lengths of the first `cycle_nr` cycles + # (i.e. geometric series) which corresponds to the beginning offset + # of the current cycle. + cycle_offset = int((1 - mul**cycle_nr) / (1 - mul) * self._cycle_len) + + # The position of the step within the cycle. + cycle_pos = curr_step - cycle_offset + + lr_mul = self._lr_mul**cycle_nr + + c = math.cos(math.pi * cycle_pos / cycle_len) + + min_lrs, max_lrs = self._final_lrs, base_lrs + + return [self._cycle_lr(mn, mx, lr_mul, c) for mn, mx in zip(min_lrs, max_lrs)] + + def _cycle_lr(self, min_lr: float, max_lr: float, lr_mul: float, c: float) -> float: + min_lr *= lr_mul + max_lr *= lr_mul + + return min_lr + 0.5 * (max_lr - min_lr) * (1 + c) + + +COSINE_ANNEALING_LR: Final = "cosine-annealing" + + +@dataclass(kw_only=True) +class CosineAnnealingLRConfig: + cycle_len: int | None = None + """The number of steps within the first cycle. If ``None``, will be set to + ``num_steps - num_warmup_steps``.""" + + num_warmup_steps: int = 0 + """The number of warmup steps.""" + + cycle_mul: float = 1.0 + """The factor to grow the length of each cycle.""" + + lr_mul: float = 1.0 + """The factor to scale the base and final learning rate at the end of each + cycle.""" + + start_lr: float = 0.0 + """The initial warmup learning rate.""" + + final_lr: float | None = None + """The final learning rate. If ``None``, :attr:`final_lr_scale` will be used.""" + + final_lr_scale: float | None = 0.2 + """ + The optimizer learning rate will be scaled by this value to determine the + final learning rate. If ``None``, :attr:`final_lr` will be used. + """ + + +@final +class CosineAnnealingLRHandler(LRSchedulerHandler): + @override + def create( + self, optimizer: Optimizer, config: object, num_steps: int | None + ) -> LRScheduler: + config = safe_cast("config", config, CosineAnnealingLRConfig) + + if config.cycle_len is None: + if num_steps is None: + raise ValueError( + "`config.cycle_len` must be specified when `num_steps` is not specified." + ) + + cycle_len = num_steps - config.num_warmup_steps + else: + cycle_len = config.cycle_len + + if config.final_lr is not None and config.final_lr_scale is not None: + raise ValueError( + "`config.final_lr` and `config.final_lr_scale` must not be specified at the same time." + ) + + try: + lr = optimizer.param_groups[0]["lr"] + except (IndexError, KeyError): + raise ValueError( + "`optimizer` does not have a parameter group with an assigned learning rate." + ) from None + + if config.final_lr_scale is not None: + final_lr = lr * config.final_lr_scale + elif config.final_lr is not None: + final_lr = config.final_lr + else: + raise ValueError( + "Either `config.final_lr` or `config.final_lr_scale` must be specified." + ) + + if final_lr > lr: + log.warning("The final learning rate ({}) is greater than the optimizer learning rate ({}). This means your learning rate will increase over the course of the training.", final_lr, lr) # fmt: skip + + return CosineAnnealingLR( + optimizer, + cycle_len, + config.num_warmup_steps, + cycle_mul=config.cycle_mul, + lr_mul=config.lr_mul, + start_lr=config.start_lr, + final_lr=final_lr, + ) + + @property + @override + def requires_num_steps(self) -> bool: + return False + + @property + @override + def config_kls(self) -> type: + return CosineAnnealingLRConfig diff --git a/src/fairseq2/optim/lr_scheduler/handler.py b/src/fairseq2/optim/lr_scheduler/handler.py new file mode 100644 index 000000000..a11eda3e0 --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/handler.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from torch.optim import Optimizer + +from fairseq2.optim.lr_scheduler.lr_scheduler import LRScheduler + + +class LRSchedulerHandler(ABC): + @abstractmethod + def create( + self, optimizer: Optimizer, config: object, num_steps: int | None + ) -> LRScheduler: + ... + + @property + @abstractmethod + def requires_num_steps(self) -> bool: + ... + + @property + @abstractmethod + def config_kls(self) -> type: + ... + + +class LRSchedulerNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known learning rate scheduler.") + + self.name = name diff --git a/src/fairseq2/optim/lr_scheduler/lr_scheduler.py b/src/fairseq2/optim/lr_scheduler/lr_scheduler.py new file mode 100644 index 000000000..da738038f --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/lr_scheduler.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import warnings +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import TypeAlias, final + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from typing_extensions import override + +LRScheduler: TypeAlias = _LRScheduler + + +class AbstractLRScheduler(ABC, LRScheduler): + """Provides a skeletal implementation of :class:`LRScheduler`.""" + + @final + @override + def get_lr(self) -> list[float]: # type: ignore[override] + if not self._get_lr_called_within_step: # type: ignore[attr-defined] + warnings.warn( + "To get the last learning rate computed by the scheduler, use `get_last_lr()`." + ) + + return self._compute_lrs() + + @abstractmethod + def _compute_lrs(self) -> list[float]: + """Compute the learning rate of each parameter group.""" + + +@final +class NoopLR(AbstractLRScheduler): + """Represents a no-op learning rate schedule.""" + + def __init__(self, optimizer: Optimizer, *, last_epoch: int = -1) -> None: + super().__init__(optimizer, last_epoch) + + @override + def _compute_lrs(self) -> list[float]: + return self.base_lrs + + +def get_per_param_group( + optimizer: Optimizer, name: str, value: float | Sequence[float] +) -> Sequence[float]: + num_param_groups = len(optimizer.param_groups) + + if isinstance(value, float): + return [value] * num_param_groups + + if len(value) != num_param_groups: + raise ValueError( + f"The length of `{name}` must be equal to the number of parameter groups ({num_param_groups}), but is {len(value)} instead." + ) + + return value + + +def get_effective_lr(scheduler: LRScheduler) -> float: + """Return the effective learning rate computed by ``scheduler``.""" + return scheduler.get_last_lr()[0] diff --git a/src/fairseq2/optim/lr_scheduler/myle.py b/src/fairseq2/optim/lr_scheduler/myle.py new file mode 100644 index 000000000..562317a3a --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/myle.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Final, final + +from torch.optim import Optimizer +from typing_extensions import override + +from fairseq2.optim.lr_scheduler.handler import LRSchedulerHandler +from fairseq2.optim.lr_scheduler.lr_scheduler import ( + AbstractLRScheduler, + LRScheduler, + get_per_param_group, +) +from fairseq2.typing import safe_cast + + +@final +class MyleLR(AbstractLRScheduler): + """Represents a scaled version of :class:`NoamLR` that preserves the base + learning rate of the associated optimizer. + + .. math:: + \\eta_t = \\eta_{base} \\min(\\sqrt{\\frac{T_{warmup}}{t}}, \\frac{t}{T_{warmup}}) + + Essentially, this is Noam learning rate schedule scaled by the square root + of the number of warmup steps. It was originally proposed and implemented by + Myle Ott in fairseq under the name ``InverseSquareRootLR``. + + It corresponds to increasing the learning rate linearly for the first + :math:`T_{warmup}` training steps to the base learning rate, and decreasing + it thereafter proportionally to the inverse square root of the step number. + + .. note:: + This scheduler is not chainable. + """ + + _num_warmup_steps: int + _start_lrs: Sequence[float] + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + *, + start_lr: float | Sequence[float] = 0.0, + last_epoch: int = -1, + ) -> None: + """ + :param optimizer: + The optimizer to associate. + :param num_warmup_steps: + The number of warmup steps. + :param start_lr: + The initial warmup learning rate of all parameter groups, or of each + parameter group respectively. + :param last_epoch: + The index of the last epoch. + """ + if num_warmup_steps == 0: + raise ValueError("`num_warmup_steps` must be greater than 0.") + + self._num_warmup_steps = num_warmup_steps + + self._start_lrs = get_per_param_group(optimizer, "start_lr", start_lr) + + super().__init__(optimizer, last_epoch) + + @override + def _compute_lrs(self) -> list[float]: + base_lrs = self.base_lrs + + # Linearly increase the learning rate to its base value during warmup. + if self.last_epoch < self._num_warmup_steps: + c = self.last_epoch / self._num_warmup_steps + + return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] + + # After the warmup, decay the learning rate proportional to the inverse + # square root of the step number. + c = (self._num_warmup_steps / self.last_epoch) ** 0.5 + + return [b * c for b in base_lrs] + + +MYLE_LR: Final = "myle" + + +@dataclass(kw_only=True) +class MyleLRConfig: + num_warmup_steps: int = 0 + """The number of warmup steps.""" + + start_lr: float = 0.0 + """The initial warmup learning rate.""" + + +@final +class MyleLRHandler(LRSchedulerHandler): + @override + def create( + self, optimizer: Optimizer, config: object, num_steps: int | None + ) -> LRScheduler: + config = safe_cast("config", config, MyleLRConfig) + + return MyleLR(optimizer, config.num_warmup_steps, start_lr=config.start_lr) + + @property + @override + def requires_num_steps(self) -> bool: + return False + + @property + @override + def config_kls(self) -> type: + return MyleLRConfig diff --git a/src/fairseq2/optim/lr_scheduler/noam.py b/src/fairseq2/optim/lr_scheduler/noam.py new file mode 100644 index 000000000..024878761 --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/noam.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Final, final + +from torch.optim import Optimizer +from typing_extensions import override + +from fairseq2.optim.lr_scheduler.handler import LRSchedulerHandler +from fairseq2.optim.lr_scheduler.lr_scheduler import AbstractLRScheduler, LRScheduler +from fairseq2.typing import safe_cast + + +@final +class NoamLR(AbstractLRScheduler): + """Represents the learning rate schedule described in Section 5.3 of + :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. + + .. math:: + \\eta_t = \\eta_{base} \\min(\\frac{1}{\\sqrt{t}}, \\frac{t}{T_{warmup}} \\frac{1}{\\sqrt{T_{warmup}}}) + + This corresponds to increasing the learning rate linearly for the first + :math:`T_{warmup}` training steps, and decreasing it thereafter + proportionally to the inverse square root of the step number. In the paper, + the authors use the square root of the dimensionality of the model as + :math:`\\eta_{base}`. + + This scheduler is commonly referred to as Noam, after the second author of + the paper, Noam Shazeer. + + .. note:: + This scheduler is not chainable. + """ + + _num_warmup_steps: int + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + *, + last_epoch: int = -1, + ) -> None: + """ + :param optimizer: + The optimizer to associate. + :param num_warmup_steps: + The number of warmup steps. + :param last_epoch: + The index of the last epoch. + """ + self._num_warmup_steps = num_warmup_steps + + super().__init__(optimizer, last_epoch) + + @override + def _compute_lrs(self) -> list[float]: + # Linearly increase the learning rate during warmup. + if self.last_epoch < self._num_warmup_steps: + c = self.last_epoch * self._num_warmup_steps**-1.5 + + # No warmup requested, decay from the base learning rate. + elif self.last_epoch == 0: + c = 1.0 + + # After the warmup, decay the learning rate proportional to the inverse + # square root of the step number. + else: + c = self.last_epoch**-0.5 + + return [b * c for b in self.base_lrs] + + +NOAM_LR: Final = "noam" + + +@dataclass(kw_only=True) +class NoamLRConfig: + num_warmup_steps: int = 0 + """The number of warmup steps.""" + + +@final +class NoamLRHandler(LRSchedulerHandler): + @override + def create( + self, optimizer: Optimizer, config: object, num_steps: int | None + ) -> LRScheduler: + config = safe_cast("config", config, NoamLRConfig) + + return NoamLR(optimizer, config.num_warmup_steps) + + @property + @override + def requires_num_steps(self) -> bool: + return False + + @property + @override + def config_kls(self) -> type: + return NoamLRConfig diff --git a/src/fairseq2/optim/lr_scheduler/polynomial_decay.py b/src/fairseq2/optim/lr_scheduler/polynomial_decay.py new file mode 100644 index 000000000..bc6528498 --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/polynomial_decay.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Final, final + +from torch.optim import Optimizer +from typing_extensions import override + +from fairseq2.optim.lr_scheduler.handler import LRSchedulerHandler +from fairseq2.optim.lr_scheduler.lr_scheduler import ( + AbstractLRScheduler, + LRScheduler, + get_per_param_group, +) +from fairseq2.typing import safe_cast + + +@final +class PolynomialDecayLR(AbstractLRScheduler): + """Represents the polynomial decay learning rate schedule. + + **During warmup:** + + .. math:: + \\eta_t = \\eta_{base} \\frac{t}{T_{warmup}} + + **After warmup:** + + .. math:: + \\eta_t = \\eta_{final} + (\\eta_{base} - \\eta_{final}) (\\frac{T - t}{T - T_{warmup}})^{p} + + This corresponds to increasing the learning rate linearly for the first + :math:`T_{warmup}` training steps to the base learning rate, and decreasing + it thereafter for :math:`T - T_{warmup}` steps to the final learning rate + using a polynomial of degree :math:`p`. + + .. note:: + This scheduler is not chainable. + """ + + _num_steps: int + _num_warmup_steps: int + _power: float + _start_lrs: Sequence[float] + _final_lrs: Sequence[float] + + def __init__( + self, + optimizer: Optimizer, + num_steps: int, + num_warmup_steps: int, + *, + power: float = 1.0, + start_lr: float | Sequence[float] = 0.0, + final_lr: float | Sequence[float] = 0.0, + last_epoch: int = -1, + ) -> None: + """ + :param optimizer: + The optimizer to associate. + :param num_steps: + The total number of steps, including warmup, over which to decay the + learning rate. + :param num_warmup_steps: + The number of warmup steps. + :param power: + The exponent of the polynomial used for decay. + :param start_lr: + The initial warmup learning rate of all parameter groups, or of each + parameter group respectively. + :param final_lr: + The final learning rate of all parameter groups, or of each + parameter group respectively. + :param last_epoch: + The index of the last epoch. + """ + if num_warmup_steps >= num_steps: + raise ValueError( + f"`num_warmup_steps` must be less than `num_steps` ({num_steps}), but is {num_warmup_steps} instead." + ) + + self._num_steps = num_steps + self._num_warmup_steps = num_warmup_steps + self._power = power + + self._start_lrs = get_per_param_group(optimizer, "start_lr", start_lr) + self._final_lrs = get_per_param_group(optimizer, "final_lr", final_lr) + + super().__init__(optimizer, last_epoch) + + @override + def _compute_lrs(self) -> list[float]: + base_lrs = self.base_lrs + + # The decay is already complete, return the final learning rate. + if self.last_epoch >= self._num_steps: + return [f for f in self._final_lrs] + + # Linearly increase the learning rate to its base value during warmup. + if self.last_epoch < self._num_warmup_steps: + c = self.last_epoch / self._num_warmup_steps + + return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] + + # After the warmup, decay the learning rate to its final value. + r = self._num_steps - self.last_epoch + t = self._num_steps - self._num_warmup_steps + + c = (r / t) ** self._power + + return [f + (b - f) * c for b, f in zip(base_lrs, self._final_lrs)] + + +POLYNOMIAL_DECAY_LR: Final = "polynomial-decay" + + +@dataclass(kw_only=True) +class PolynomialDecayLRConfig: + num_warmup_steps: int = 0 + """The number of warmup steps.""" + + power: float = 1.0 + """The exponent of the polynomial used for decay.""" + + start_lr: float = 0.0 + """The initial warmup learning rate.""" + + final_lr: float = 0.0 + """The final learning rate.""" + + +@final +class PolynomialDecayLRHandler(LRSchedulerHandler): + @override + def create( + self, optimizer: Optimizer, config: object, num_steps: int | None + ) -> LRScheduler: + config = safe_cast("config", config, PolynomialDecayLRConfig) + + if num_steps is None: + raise ValueError("`num_steps` must specified.") + + return PolynomialDecayLR( + optimizer, + num_steps, + config.num_warmup_steps, + power=config.power, + start_lr=config.start_lr, + final_lr=config.final_lr, + ) + + @property + @override + def requires_num_steps(self) -> bool: + return True + + @property + @override + def config_kls(self) -> type: + return PolynomialDecayLRConfig diff --git a/src/fairseq2/optim/lr_scheduler/static.py b/src/fairseq2/optim/lr_scheduler/static.py new file mode 100644 index 000000000..9620eb477 --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/static.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from torch.optim import Optimizer + +from fairseq2.context import get_runtime_context +from fairseq2.optim.lr_scheduler.handler import ( + LRSchedulerHandler, + LRSchedulerNotFoundError, +) +from fairseq2.optim.lr_scheduler.lr_scheduler import LRScheduler +from fairseq2.utils.config import process_config +from fairseq2.utils.structured import structure + + +def create_lr_scheduler( + name: str, + optimizer: Optimizer, + config: object = None, + *, + max_num_steps: int | None = None, +) -> LRScheduler: + context = get_runtime_context() + + registry = context.get_registry(LRSchedulerHandler) + + try: + handler = registry.get(name) + except LookupError: + raise LRSchedulerNotFoundError(name) from None + + if config is None: + try: + config = handler.config_kls() + except TypeError: + raise ValueError( + f"`config` must be specified for the '{name}' learning rate scheduler." + ) from None + else: + config = structure(config, handler.config_kls) + + process_config(config) + + return handler.create(optimizer, config, max_num_steps) diff --git a/src/fairseq2/optim/lr_scheduler/tri_stage.py b/src/fairseq2/optim/lr_scheduler/tri_stage.py new file mode 100644 index 000000000..9ac7e9de8 --- /dev/null +++ b/src/fairseq2/optim/lr_scheduler/tri_stage.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Final, final + +from torch.optim import Optimizer +from typing_extensions import override + +from fairseq2.optim.lr_scheduler.handler import LRSchedulerHandler +from fairseq2.optim.lr_scheduler.lr_scheduler import ( + AbstractLRScheduler, + LRScheduler, + get_per_param_group, +) +from fairseq2.typing import safe_cast + + +@final +class TriStageLR(AbstractLRScheduler): + """Represents the tri-stage learning rate schedule as described in Section + 3.2 of :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. + + The learning rate schedule employs three stages: + + - The warm-up stage where the learning rate is linearly increased to its + maximum value (i.e. `base_lr`) + - The hold stage where the learning rate is kept constant at its maximum + value. + - The decay stage where the learning rate is exponentially decayed to its + final value. + + .. note:: + This scheduler is not chainable. + """ + + _num_steps: int + _start_lr_scales: Sequence[float] + _final_lr_scales: Sequence[float] + _start_lrs: Sequence[float] | None + _final_lrs: Sequence[float] | None + _num_stage1_steps: int + _num_stage2_steps: int + _num_stage3_steps: int + + def __init__( + self, + optimizer: Optimizer, + num_steps: int, + stage_ratio: tuple[float, float, float], + *, + start_lr_scale: float | Sequence[float] = 0.01, + final_lr_scale: float | Sequence[float] = 0.01, + last_epoch: int = -1, + ) -> None: + """ + :param optimizer: + The optimizer to associate. + :param num_steps: + The total number of steps over which to adjust the learning rate. + :param stage_ratio: + The ratios of warmup, hold, and decay stages. Must add up to 1. + :param start_lr_scale: + The scale of the initial warm-up learning rate. + :param final_lr_scale: + The scale of the final learning rate. + """ + if not math.isclose((s := sum(stage_ratio)), 1.0): + raise ValueError( + f"The sum of `stage_ratio` values must be 1.0, but is {s} instead." + ) + + self._num_steps = num_steps + + self._start_lr_scales = get_per_param_group( + optimizer, "start_lr", start_lr_scale + ) + self._final_lr_scales = get_per_param_group( + optimizer, "final_lr", final_lr_scale + ) + + self._start_lrs = None + self._final_lrs = None + + self._num_stage1_steps = int(stage_ratio[0] * num_steps) + self._num_stage2_steps = int(stage_ratio[1] * num_steps) + self._num_stage3_steps = int(stage_ratio[2] * num_steps) + + super().__init__(optimizer, last_epoch) + + @override + def _compute_lrs(self) -> list[float]: + base_lrs = self.base_lrs + + # Due to `LRScheduler`'s constructor quirks, we delay the initialization + # of `start_lrs` and `final_lrs` to here. + if self._start_lrs is None: + self._start_lrs = [s * b for s, b in zip(self._start_lr_scales, base_lrs)] + + if self._final_lrs is None: + self._final_lrs = [s * b for s, b in zip(self._final_lr_scales, base_lrs)] + + num_steps = self.last_epoch + + # Linearly increase the learning rate to its base value during warmup. + if num_steps < self._num_stage1_steps: + c = num_steps / self._num_stage1_steps + + return [s + (b - s) * c for b, s in zip(base_lrs, self._start_lrs)] + + num_steps -= self._num_stage1_steps + + # Keep the learning rate constant during second stage. + if num_steps < self._num_stage2_steps: + return list(base_lrs) + + num_steps -= self._num_stage2_steps + + if num_steps < self._num_stage3_steps: + c = num_steps / self._num_stage3_steps + + return [b * math.exp(math.log(f) * c) for b, f in zip(base_lrs, self._final_lr_scales)] # fmt: skip + + return list(self._final_lrs) + + +TRI_STAGE_LR: Final = "tri-stage" + + +@dataclass(kw_only=True) +class TriStageLRConfig: + stage_ratio: tuple[float, float, float] = (0.0, 0.0, 1.0) + """The ratios of warmup, hold, and decay stages. Must add up to 1.""" + + start_lr_scale: float = 0.01 + """The scale of the initial warm-up learning rate.""" + + final_lr_scale: float = 0.01 + """The scale of the final learning rate.""" + + +@final +class TriStageLRHandler(LRSchedulerHandler): + @override + def create( + self, optimizer: Optimizer, config: object, num_steps: int | None + ) -> LRScheduler: + config = safe_cast("config", config, TriStageLRConfig) + + if num_steps is None: + raise ValueError("`num_steps` must specified.") + + return TriStageLR( + optimizer, + num_steps, + config.stage_ratio, + start_lr_scale=config.start_lr_scale, + final_lr_scale=config.final_lr_scale, + ) + + @property + @override + def requires_num_steps(self) -> bool: + return False + + @property + @override + def config_kls(self) -> type: + return TriStageLRConfig diff --git a/src/fairseq2/optim/optimizer.py b/src/fairseq2/optim/optimizer.py index 2cdf40025..0db0ebea4 100644 --- a/src/fairseq2/optim/optimizer.py +++ b/src/fairseq2/optim/optimizer.py @@ -7,19 +7,23 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Callable, Optional, final +from collections.abc import Callable, Iterable +from typing import TypeAlias, final import torch +from torch import Tensor from torch.optim import Optimizer +ParameterCollection: TypeAlias = Iterable[Tensor] | Iterable[dict[str, object]] + class AbstractOptimizer(ABC, Optimizer): """Provides a skeletal implementation of :class:`Optimizer`.""" @final def step( # type: ignore[override] - self, closure: Optional[Callable[[], float]] = None - ) -> Optional[float]: + self, closure: Callable[[], float] | None = None + ) -> float | None: loss = None prev_grad = torch.is_grad_enabled() diff --git a/src/fairseq2/optim/static.py b/src/fairseq2/optim/static.py new file mode 100644 index 000000000..f48da152f --- /dev/null +++ b/src/fairseq2/optim/static.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from torch.nn import Module +from torch.optim import Optimizer + +from fairseq2.context import get_runtime_context +from fairseq2.optim.handler import OptimizerHandler, OptimizerNotFoundError +from fairseq2.optim.optimizer import ParameterCollection +from fairseq2.utils.config import process_config +from fairseq2.utils.structured import structure + + +def create_optimizer( + name: str, params: ParameterCollection | Module, config: object = None +) -> Optimizer: + context = get_runtime_context() + + registry = context.get_registry(OptimizerHandler) + + try: + handler = registry.get(name) + except LookupError: + raise OptimizerNotFoundError(name) from None + + if config is None: + try: + config = handler.config_kls() + except TypeError: + raise ValueError( + f"`config` must be specified for the '{name}' optimizer." + ) from None + else: + config = structure(config, handler.config_kls) + + process_config(config) + + if isinstance(params, Module): + params = params.parameters() + + return handler.create(params, config) diff --git a/src/fairseq2/recipes/__init__.py b/src/fairseq2/recipes/__init__.py index 1124040e8..4bb2db6bb 100644 --- a/src/fairseq2/recipes/__init__.py +++ b/src/fairseq2/recipes/__init__.py @@ -6,45 +6,77 @@ from __future__ import annotations +import os +import sys +from signal import SIG_DFL, SIGINT, raise_signal, signal +from typing import Iterator + +import torch from importlib_metadata import entry_points +from torch.cuda import OutOfMemoryError +from fairseq2.error import ContractError, InternalError +from fairseq2.extensions import ExtensionError, run_extensions +from fairseq2.logging import log from fairseq2.recipes.cli import Cli # isort: split -import os - -from fairseq2.logging import get_log_writer from fairseq2.recipes.assets import _setup_asset_cli +from fairseq2.recipes.hg import _setup_hg_cli from fairseq2.recipes.llama import _setup_llama_cli from fairseq2.recipes.lm import _setup_lm_cli from fairseq2.recipes.logging import setup_basic_logging from fairseq2.recipes.mt import _setup_mt_cli -from fairseq2.recipes.utils.log import exception_logger +from fairseq2.recipes.wav2vec2 import _setup_wav2vec2_cli from fairseq2.recipes.wav2vec2.asr import _setup_wav2vec2_asr_cli -log = get_log_writer(__name__) - def main() -> None: """Run the command line fairseq2 program.""" - from fairseq2 import __version__, setup_extensions + exit_code = 1 + + try: + exit_code = _run() + except KeyboardInterrupt: + log.info("The command has been canceled!") + + signal(SIGINT, SIG_DFL) + + raise_signal(SIGINT) + except OutOfMemoryError: + s = torch.cuda.memory_summary() + + log.exception("CUDA run out of memory. See the logged memory stats.\n{}", s) + except ExtensionError as ex: + log.exception("The '{}' extension has failed to load. See the logged stack trace for details.", ex.entry_point) # fmt: skip + except InternalError: + log.exception("The command has failed with an unexpected internal error. Please file a bug report.") # fmt: skip + except ContractError: + log.exception("The command has failed with an unexpected internal error caused by an extension. See the logged stack trace for details and file a bug report to the corresponding extension author.") # fmt: skip + except Exception: + log.exception("The command has failed with an unexpected error. See the logged stack trace for details.") # fmt: skip + + sys.exit(exit_code) + + +def _run() -> int: + from fairseq2 import __version__, setup_fairseq2 - with exception_logger(log): - setup_basic_logging() + setup_basic_logging() - setup_extensions() + setup_fairseq2() - cli = Cli( - name="fairseq2", - origin_module="fairseq2", - version=__version__, - description="command line interface of fairseq2", - ) + cli = Cli( + name="fairseq2", + origin_module="fairseq2", + version=__version__, + description="command line interface of fairseq2", + ) - _setup_cli(cli) + _setup_cli(cli) - cli() + return cli.run() def _setup_cli(cli: Cli) -> None: @@ -52,25 +84,8 @@ def _setup_cli(cli: Cli) -> None: _setup_lm_cli(cli) _setup_llama_cli(cli) _setup_mt_cli(cli) + _setup_wav2vec2_cli(cli) _setup_wav2vec2_asr_cli(cli) + _setup_hg_cli(cli) - # Set up 3rd party CLI extensions. - for entry_point in entry_points(group="fairseq2.cli"): - try: - setup_cli_extension = entry_point.load() - - setup_cli_extension(cli) - except TypeError: - raise RuntimeError( - f"The entry point '{entry_point.value}' is not a valid fairseq2 CLI setup function." - ) from None - except Exception as ex: - if "FAIRSEQ2_EXTENSION_TRACE" in os.environ: - raise RuntimeError( - f"The CLI setup function at '{entry_point.value}' has failed. See nested exception for details." - ) from ex - - log.warning( - "The CLI setup function at '{}' has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", - entry_point.value, - ) + run_extensions("fairseq2.cli", cli) diff --git a/src/fairseq2/recipes/assets.py b/src/fairseq2/recipes/assets.py index 86777fddb..122dddeb4 100644 --- a/src/fairseq2/recipes/assets.py +++ b/src/fairseq2/recipes/assets.py @@ -6,27 +6,19 @@ from __future__ import annotations -import sys from argparse import ArgumentParser, Namespace from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, final +from typing import final from rich.console import Console from rich.pretty import pretty_repr +from typing_extensions import override -from fairseq2.assets import ( - AssetCard, - AssetNotFoundError, - AssetStore, - default_asset_store, -) -from fairseq2.console import get_console -from fairseq2.data.text import is_tokenizer_card -from fairseq2.datasets import is_dataset_card +from fairseq2.assets import AssetCard, AssetCardNotFoundError, AssetStore +from fairseq2.context import get_runtime_context from fairseq2.logging import get_log_writer -from fairseq2.models import is_model_card from fairseq2.recipes.cli import Cli, CliCommandHandler -from fairseq2.typing import override +from fairseq2.recipes.utils.rich import get_console log = get_log_writer(__name__) @@ -38,31 +30,21 @@ def _setup_asset_cli(cli: Cli) -> None: group.add_command( "list", - ListAssetsCommand(), + ListAssetsHandler(), help="list assets", ) group.add_command( "show", - ShowAssetCommand(), + ShowAssetHandler(), help="show asset", ) @final -class ListAssetsCommand(CliCommandHandler): +class ListAssetsHandler(CliCommandHandler): """Lists assets available in the current Python environment.""" - _asset_store: AssetStore - - def __init__(self, asset_store: Optional[AssetStore] = None) -> None: - """ - :param asset_store: - The asset store from which to retrieve the asset cards. If ``None``, - the default asset store will be used. - """ - self._asset_store = asset_store or default_asset_store - @override def init_parser(self, parser: ArgumentParser) -> None: parser.add_argument( @@ -73,81 +55,102 @@ def init_parser(self, parser: ArgumentParser) -> None: ) @override - def __call__(self, args: Namespace) -> None: - usr_assets = self._retrieve_assets(args, user=True) - glb_assets = self._retrieve_assets(args, user=False) + def run(self, parser: ArgumentParser, args: Namespace) -> int: + context = get_runtime_context() console = get_console() console.print("[green bold]user:") - self._dump_assets(console, usr_assets) + assets = self._retrieve_assets(context.asset_store, args.type, user=True) + + self._dump_assets(console, assets) console.print("[green bold]global:") - self._dump_assets(console, glb_assets) + assets = self._retrieve_assets(context.asset_store, args.type, user=False) + + self._dump_assets(console, assets) + + return 0 + @classmethod def _retrieve_assets( - self, args: Namespace, user: bool - ) -> List[Tuple[str, List[str]]]: - assets: Dict[str, List[str]] = defaultdict(list) + cls, asset_store: AssetStore, asset_type: str, user: bool + ) -> list[tuple[str, list[str]]]: + assets: dict[str, list[str]] = defaultdict(list) - names = self._asset_store.retrieve_names(scope="user" if user else "global") + asset_names = asset_store.retrieve_names(scope="user" if user else "global") - for name in names: + for asset_name in asset_names: try: - card = self._asset_store.retrieve_card( - name, scope="all" if user else "global" + card = asset_store.retrieve_card( + asset_name, scope="all" if user else "global" ) - except AssetNotFoundError: - log.warning("The asset '{}' has an invalid card. Skipping.", name) + except AssetCardNotFoundError: + log.warning("The '{}' asset card is not valid. Skipping.", asset_name) continue - if name[-1] == "@": - name = name[:-1] + if asset_name[-1] == "@": + asset_name = asset_name[:-1] - try: - source = card.metadata["__source__"] - except KeyError: + source = card.metadata.get("__source__", "unknown source") + if not isinstance(source, str): source = "unknown source" - types = [] + asset_types = [] - if args.type == "all" or args.type == "model": - if is_model_card(card): - types.append("model") + if asset_type == "all" or asset_type == "model": + if cls._is_model_card(card): + asset_types.append("model") - if args.type == "all" or args.type == "dataset": - if is_dataset_card(card): - types.append("dataset") + if asset_type == "all" or asset_type == "dataset": + if cls._is_dataset_card(card): + asset_types.append("dataset") - if args.type == "all" or args.type == "tokenizer": - if is_tokenizer_card(card): - types.append("tokenizer") + if asset_type == "all" or asset_type == "tokenizer": + if cls._is_tokenizer_card(card): + asset_types.append("tokenizer") - if args.type == "all" and not types: - types.append("other") + if asset_type == "all" and not asset_types: + asset_types.append("other") - if not types: + if not asset_types: continue source_assets = assets[source] - for t in types: - source_assets.append(f"{t}:{name}") + for t in asset_types: + source_assets.append(f"{t}:{asset_name}") - return [(source, names) for source, names in assets.items()] + output = [] - def _dump_assets( - self, console: Console, assets: List[Tuple[str, List[str]]] - ) -> None: - if assets: - assets.sort(key=lambda a: a[0]) # sort by source. + for source, asset_names in assets.items(): + asset_names.sort() - for source, names in assets: - names.sort(key=lambda n: n[0]) # sort by name. + output.append((source, asset_names)) + + output.sort(key=lambda e: e[0]) # sort by source + + return output + @staticmethod + def _is_model_card(card: AssetCard) -> bool: + return card.field("model_family").exists() + + @staticmethod + def _is_tokenizer_card(card: AssetCard) -> bool: + return card.field("tokenizer_family").exists() + + @staticmethod + def _is_dataset_card(card: AssetCard) -> bool: + return card.field("dataset_family").exists() + + @staticmethod + def _dump_assets(console: Console, assets: list[tuple[str, list[str]]]) -> None: + if assets: + for source, names in assets: console.print(f" [blue bold]{source}") for idx, name in enumerate(names): @@ -159,19 +162,9 @@ def _dump_assets( console.print() -class ShowAssetCommand(CliCommandHandler): +class ShowAssetHandler(CliCommandHandler): """Shows the metadata of an asset.""" - _asset_store: AssetStore - - def __init__(self, asset_store: Optional[AssetStore] = None) -> None: - """ - :param asset_store: - The asset store from which to retrieve the asset cards. If ``None``, - the default asset store will be used. - """ - self._asset_store = asset_store or default_asset_store - @override def init_parser(self, parser: ArgumentParser) -> None: parser.add_argument( @@ -192,32 +185,28 @@ def init_parser(self, parser: ArgumentParser) -> None: parser.add_argument("name", help="name of the asset") @override - def __call__(self, args: Namespace) -> None: - try: - card: Optional[AssetCard] = self._asset_store.retrieve_card( - args.name, envs=args.envs, scope=args.scope - ) - except AssetNotFoundError: - log.error("An asset with the name '{}' cannot be found.", args.asset) + def run(self, parser: ArgumentParser, args: Namespace) -> int: + context = get_runtime_context() - sys.exit(1) + card: AssetCard | None = context.asset_store.retrieve_card( + args.name, envs=args.envs, scope=args.scope + ) while card is not None: self._print_metadata(dict(card.metadata)) card = card.base - def _print_metadata(self, metadata: Dict[str, Any]) -> None: + return 0 + + def _print_metadata(self, metadata: dict[str, object]) -> None: console = get_console() name = metadata.pop("name") console.print(f"[green bold]{name}") - try: - source = metadata.pop("__source__") - except KeyError: - source = "unknown" + source = metadata.pop("__source__", "unknown") items = list(metadata.items()) diff --git a/src/fairseq2/recipes/cli.py b/src/fairseq2/recipes/cli.py index e90e2b47b..fb931771d 100644 --- a/src/fairseq2/recipes/cli.py +++ b/src/fairseq2/recipes/cli.py @@ -8,52 +8,62 @@ import sys from abc import ABC, abstractmethod -from argparse import OPTIONAL, ArgumentParser, Namespace -from copy import deepcopy +from argparse import OPTIONAL, ArgumentParser, BooleanOptionalAction, Namespace +from collections.abc import Hashable, Set from pathlib import Path -from signal import SIGUSR1, signal -from types import FrameType -from typing import ( - Callable, - Dict, - Generic, - Optional, - Protocol, - TypeVar, - final, - runtime_checkable, -) +from typing import TypeVar, final -import yaml from rich.console import Console -from yaml import YAMLError - -from fairseq2.config_registry import ConfigRegistry -from fairseq2.console import get_console, set_console -from fairseq2.logging import get_log_writer -from fairseq2.recipes.logging import setup_basic_logging, setup_logging -from fairseq2.recipes.utils.argparse import BooleanOptionalAction, ConfigAction -from fairseq2.recipes.utils.environment import ( - EnvironmentSetterRegistry, - default_env_setters, +from typing_extensions import override + +from fairseq2.config_registry import ConfigNotFoundError, ConfigProvider +from fairseq2.context import get_runtime_context +from fairseq2.error import AlreadyExistsError, InvalidOperationError, SetupError +from fairseq2.gang import is_torchrun +from fairseq2.logging import log +from fairseq2.recipes.cluster import ( + ClusterError, + ClusterHandler, + ClusterResolver, + UnknownClusterError, ) -from fairseq2.recipes.utils.log import log_config -from fairseq2.recipes.utils.sweep import SweepTagger, default_sweep_tagger -from fairseq2.typing import DataClass, override -from fairseq2.utils.dataclass import FieldError, dump_dataclass, update_dataclass -from fairseq2.utils.value_converter import ValueConverter, default_value_converter - -log = get_log_writer(__name__) +from fairseq2.recipes.logging import DistributedLoggingInitializer, setup_basic_logging +from fairseq2.recipes.runner import ( + ConfigFileNotFoundError, + ConfigReader, + EnvironmentBootstrapper, + Recipe, + RecipeLoader, + RecipeRunner, + StandardConfigReader, + StandardEnvironmentBootstrapper, + StandardRecipeRunner, + SystemSignalHandler, + get_sweep_keys, +) +from fairseq2.recipes.utils.argparse import ConfigAction +from fairseq2.recipes.utils.rich import get_console, set_console +from fairseq2.recipes.utils.sweep_tagger import ( + NoopSweepTagger, + StandardSweepTagger, + SweepFormatError, + SweepFormatPlaceholderError, + SweepTagger, +) +from fairseq2.typing import safe_cast +from fairseq2.utils.file import StandardFileSystem +from fairseq2.utils.structured import StructureError, unstructure +from fairseq2.utils.yaml import YamlDumper, YamlError, dump_yaml, load_yaml class Cli: """Represents the entry point of a command line program.""" _name: str + _description: str | None _origin_module: str _version: str - _description: Optional[str] - _groups: Dict[str, CliGroup] + _groups: dict[str, CliGroup] def __init__( self, @@ -61,59 +71,54 @@ def __init__( origin_module: str, *, version: str, - description: Optional[str] = None, + description: str | None = None, ) -> None: """ - :param name: - The name of the program. - :param origin_module: - The name of the origin Python module of the command line program. - :param version: - The version of the program. - :param description: - The description of the program. + :param name: The name of the program. + :param origin_module: The name of the origin Python module of the + command line program. + :param version: The version of the program. + :param description: The description of the program. """ self._name = name + self._description = description self._origin_module = origin_module self._version = version - self._description = description self._groups = {} def add_group( self, name: str, *, - origin_module: Optional[str] = None, - help: Optional[str] = None, + help: str | None = None, + origin_module: str | None = None, ) -> CliGroup: - """Add a command group. - - :param name: - The name of the command group. - :param origin_module: - The name of origin Python module of the command group. - :param help: - The help text of the command group. - """ - if name in self._groups: - raise ValueError( - f"`name` must be a unique group name, but '{name}' is already registered." - ) + """Add a sub-group.""" + group = self._get_or_add_group(name) - group = CliGroup(name, origin_module or self._origin_module, help=help) + if help is not None: + group._help = help - self._groups[group.name] = group + if origin_module is not None: + group._origin_module = origin_module return group def get_group(self, name: str) -> CliGroup: - """Return the command group of ``name``.""" + """Get a sub-group.""" + return self._get_or_add_group(name) + + def _get_or_add_group(self, name: str) -> CliGroup: try: return self._groups[name] except KeyError: - raise ValueError( - f"`name` must be a registered group name, but is '{name}' instead." - ) from None + pass + + group = CliGroup(name, self._origin_module) + + self._groups[name] = group + + return group def init_parser(self, parser: ArgumentParser) -> None: """Initialize ``parser`` with program-specific arguments.""" @@ -138,13 +143,10 @@ def init_parser(self, parser: ArgumentParser) -> None: group.init_parser(sub_parser) - def __call__(self) -> None: + def run(self) -> int: """Run the program.""" set_console(Console(highlight=False)) - self._run_command() - - def _run_command(self) -> None: parser = ArgumentParser(self._name, description=self._description) self.init_parser(parser) @@ -152,17 +154,20 @@ def _run_command(self) -> None: args = parser.parse_args() if not hasattr(args, "command"): - parser.print_usage(sys.stderr) + parser.error("no command specified") - sys.exit(2) - - args.command(args) + return args.command.run(args) # type: ignore[no-any-return] @property def name(self) -> str: """The name of the program.""" return self._name + @property + def description(self) -> str | None: + """The description of the program.""" + return self._description + @property def origin_module(self) -> str: """The name of the origin Python module of the command line program.""" @@ -173,55 +178,73 @@ def version(self) -> str: """The version of the program.""" return self._version - @property - def description(self) -> Optional[str]: - """The description of the program.""" - return self._description - class CliGroup: """Represents a command group of a command line program.""" _name: str + _groups: dict[str, CliGroup] + _commands: dict[str, CliCommand] + _help: str | None _origin_module: str - _help: Optional[str] - _groups: Dict[str, CliGroup] - _commands: Dict[str, CliCommand] def __init__( self, name: str, origin_module: str, *, - help: Optional[str] = None, + help: str | None = None, ) -> None: self._name = name - self._origin_module = origin_module - self._help = help self._groups = {} self._commands = {} + self._help = help + self._origin_module = origin_module def add_group( self, name: str, *, - origin_module: Optional[str] = None, - help: Optional[str] = None, + help: str | None = None, + origin_module: str | None = None, ) -> CliGroup: - """Add a sub-command group. - - :param name: - The name of the command group. - :param origin_module: - The name of origin Python module of the command group. - :param help: - The help text of the command group. - """ - self._check_name(name) + """Add a sub-group.""" + group = self._get_or_add_group(name) + if group is None: + raise AlreadyExistsError( + f"The command group has already a command named '{name}'." + ) + + if help is not None: + group._help = help - group = CliGroup(name, origin_module or self._origin_module, help=help) + if origin_module is not None: + group._origin_module = origin_module - self._groups[group.name] = group + return group + + def get_group(self, name: str) -> CliGroup: + """Get a sub-group.""" + group = self._get_or_add_group(name) + if group is None: + raise LookupError( + f"The command group does not have a sub-group named '{name}'." + ) from None + + return group + + def _get_or_add_group(self, name: str) -> CliGroup | None: + try: + return self._groups[name] + except KeyError: + pass + + if name in self._commands: + return None + + group = CliGroup(name, self.origin_module) + + self._groups[name] = group return group @@ -230,57 +253,41 @@ def add_command( name: str, handler: CliCommandHandler, *, - origin_module: Optional[str] = None, - help: Optional[str] = None, + help: str | None = None, + origin_module: str | None = None, ) -> CliCommand: """Add a command. - :param name: - The name of the command. - :param handler: - The handler of the command. - :param origin_module: - The name of origin Python module of the command. - :param help: - The help text of the command. + :param name: The name of the command. + :param handler: The handler of the command. + :param origin_module: The name of origin Python module of the command. + :param help: The help text of the command. """ - self._check_name(name) + if name in self._groups: + raise AlreadyExistsError( + f"The command group has already a sub-group named '{name}'." + ) + + if name in self._commands: + raise AlreadyExistsError( + f"The command group has already a command named '{name}'." + ) command = CliCommand( - name, handler, origin_module or self._origin_module, help=help + name, handler, origin_module or self.origin_module, help=help ) self._commands[name] = command return command - def _check_name(self, name: str) -> None: - if name in self._groups: - raise ValueError( - f"`name` must be a unique name among groups and commands, but '{name}' is already registered as a group name." - ) - - if name in self._commands: - raise ValueError( - f"`name` must be a unique name among groups and commands, but '{name}' is already registered as a command name." - ) - - def get_group(self, name: str) -> CliGroup: - """Return the sub-command group of ``name``.""" - try: - return self._groups[name] - except KeyError: - raise ValueError( - f"`name` must be a registered group name, but is '{name}' instead." - ) from None - def get_command(self, name: str) -> CliCommand: """Return the command of ``name``.""" try: return self._commands[name] except KeyError: - raise ValueError( - f"`name` must be a registered command name, but is '{name}' instead." + raise LookupError( + f"The command group does not have a command named '{name}'." ) from None def init_parser(self, parser: ArgumentParser) -> None: @@ -290,7 +297,7 @@ def init_parser(self, parser: ArgumentParser) -> None: for group in self._groups.values(): help = group.help - if self._origin_module != group.origin_module: + if self.origin_module != group.origin_module: s = f"origin: {group.origin_module}" if help: @@ -305,7 +312,7 @@ def init_parser(self, parser: ArgumentParser) -> None: for command in self._commands.values(): help = command.help - if self._origin_module != command.origin_module: + if self.origin_module != command.origin_module: s = f"origin: {command.origin_module}" if help: @@ -324,24 +331,25 @@ def name(self) -> str: """The name of the command group.""" return self._name + @property + def help(self) -> str | None: + """The help text of the command group.""" + return self._help + @property def origin_module(self) -> str: """The name of the origin Python module of the command group.""" return self._origin_module - @property - def help(self) -> Optional[str]: - """The help text of the command group.""" - return self._help - class CliCommand: """Represents a command of a command line program.""" _name: str _handler: CliCommandHandler + _parser: ArgumentParser | None + _help: str | None _origin_module: str - _help: Optional[str] def __init__( self, @@ -349,36 +357,44 @@ def __init__( handler: CliCommandHandler, origin_module: str, *, - help: Optional[str] = None, + help: str | None = None, ) -> None: self._name = name self._handler = handler - self._origin_module = origin_module self._help = help + self._origin_module = origin_module def init_parser(self, parser: ArgumentParser) -> None: """Initialize ``parser`` with command group-specific arguments.""" self._handler.init_parser(parser) - def __call__(self, args: Namespace) -> None: + self._parser = parser + + def run(self, args: Namespace) -> int: """Run the command.""" - self._handler(args) + if self._parser is None: + raise InvalidOperationError("`init_parser()` must be called first.") + + try: + return self._handler.run(self._parser, args) + finally: + self._parser = None @property def name(self) -> str: """The name of the command.""" return self._name + @property + def help(self) -> str | None: + """The help text of the command.""" + return self._help + @property def origin_module(self) -> str: """The name of the origin Python module of the command.""" return self._origin_module - @property - def help(self) -> Optional[str]: - """The help text of the command.""" - return self._help - class CliCommandHandler(ABC): """Represents the handler of a command of a command line program.""" @@ -388,90 +404,51 @@ def init_parser(self, parser: ArgumentParser) -> None: """Initialize ``parser`` with command-specific arguments.""" @abstractmethod - def __call__(self, args: Namespace) -> None: + def run(self, parser: ArgumentParser, args: Namespace) -> int: """Run the command.""" -@runtime_checkable -class Stoppable(Protocol): - """Represents a task that supports graceful stopping.""" - - def request_stop(self) -> None: - ... - - -RecipeConfigT = TypeVar("RecipeConfigT", bound=DataClass) - -RecipeConfigT_contra = TypeVar( - "RecipeConfigT_contra", bound=DataClass, contravariant=True -) - - -class RecipeLoader(Protocol[RecipeConfigT_contra]): - """Loads a recipe.""" - - def __call__( - self, config: RecipeConfigT_contra, output_dir: Path - ) -> Callable[[], None]: - """ - :param name: - The configuration of the recipe. - :param output_dir: - The directory where to store the recipe artifacts. - """ +ConfigT = TypeVar("ConfigT") @final -class RecipeCommandHandler(CliCommandHandler, Generic[RecipeConfigT]): +class RecipeCommandHandler(CliCommandHandler): """Runs a recipe over command line.""" - _loader: RecipeLoader[RecipeConfigT] - _preset_configs: ConfigRegistry[RecipeConfigT] + _loader: RecipeLoader[object] + _preset_configs: ConfigProvider[object] _default_preset: str - _env_setters: EnvironmentSetterRegistry - _value_converter: ValueConverter - _sweep_tagger: SweepTagger - _parser: Optional[ArgumentParser] + _extra_sweep_keys: Set[Hashable] | None def __init__( self, - loader: RecipeLoader[RecipeConfigT], - preset_configs: ConfigRegistry[RecipeConfigT], + loader: RecipeLoader[ConfigT], + preset_configs: ConfigProvider[ConfigT], default_preset: str, *, - env_setters: Optional[EnvironmentSetterRegistry] = None, - value_converter: Optional[ValueConverter] = None, - sweep_tagger: Optional[SweepTagger] = None, + extra_sweep_keys: Set[Hashable] | None = None, ) -> None: """ - :param loader: - The recipe loader. - :param preset_configs: - The registry containing the preset recipe configurations. - :param default_preset: - The name of the default preset. - :param env_setters: - The registry containing cluster-specific :class:`EnvironmentSetter` - instances. - :param value_converter: - The :class:`ValueConverter` instance to use. If ``None``, the - default instance will be used. - :param sweep_tagger: - The :class:`SweepTagger` instance to use. If ``None``, the default - instance will be used. + :param loader: The recipe loader. + :param preset_configs: The registry containing the preset recipe + configurations. + :param default_preset: The name of the default preset. + :param extra_sweep_keys: The recipe specific configuration keys to + include in the sweep directory name. """ - self._loader = loader + + def untyped_loader(config: object, output_dir: Path) -> Recipe: + config = safe_cast("config", config, preset_configs.config_kls) + + return loader(config, output_dir) + + self._loader = untyped_loader self._preset_configs = preset_configs self._default_preset = default_preset - self._env_setters = env_setters or default_env_setters - self._value_converter = value_converter or default_value_converter - self._sweep_tagger = sweep_tagger or default_sweep_tagger - self._parser = None + self._extra_sweep_keys = extra_sweep_keys @override def init_parser(self, parser: ArgumentParser) -> None: - self._parser = parser - parser.add_argument( "--list-presets", action="store_true", @@ -489,6 +466,7 @@ def init_parser(self, parser: ArgumentParser) -> None: dest="config_files", metavar="CONFIG_FILE", type=Path, + action="append", nargs="*", help="yaml configuration file(s)", ) @@ -512,13 +490,14 @@ def init_parser(self, parser: ArgumentParser) -> None: help="do not create sweep directory", ) - clusters = list(self._env_setters.names()) - - clusters.sort() + parser.add_argument( + "--sweep-format", + default="ps_{preset}.ws_{world_size}.{hash}", + help="format of the sweep directory name (default: %(default)s)", + ) parser.add_argument( "--cluster", - choices=["auto"] + clusters, default="auto", help="cluster on which the recipe runs (default: %(default)s)", ) @@ -537,146 +516,147 @@ def init_parser(self, parser: ArgumentParser) -> None: ) @override - def __call__(self, args: Namespace) -> None: - console = get_console() - - setup_basic_logging(debug=args.debug) - - assert self._parser is not None - - # If requested, list the preset configurations and exit. + def run(self, parser: ArgumentParser, args: Namespace) -> int: if args.list_presets: - if self._preset_configs.names(): - console.print("available presets:") - - for preset in self._preset_configs.names(): - if preset == self._default_preset: - console.print(f" - {preset} (default)") - else: - console.print(f" - {preset}") - else: - console.print("no preset configuration found.") + self._print_presets() - sys.exit() + return 0 - # Load the specified preset configuration. - try: - preset_config = self._preset_configs.get(args.preset) - except ValueError: - log.error("'{}' is not a valid preset configuration name. Use `--list-presets` to see the available preset configurations.", args.preset) # fmt: skip - - sys.exit(1) - - config = deepcopy(preset_config) - - # Update the configuration with `--config-file`. - if args.config_files: - for config_file in args.config_files: - try: - with config_file.open() as fp: - config_overrides = yaml.safe_load(fp) - except (OSError, YAMLError): - log.exception("Configuration file '{}' cannot be read.", config_file) # fmt: skip + setup_basic_logging(debug=args.debug) - sys.exit(1) + program = self._create_recipe_program(args) - if not isinstance(config_overrides, dict): - log.error("Configuration file '{}' must contain a dictionary.", config_file) # fmt: skip + try: + program.run(args) + + return 0 + except ConfigNotFoundError as ex: + parser.error(f"argument --preset: '{ex.name}' is not a known preset configuration. Use `--list-presets` to see the available configurations.") # fmt: skip + except ConfigFileNotFoundError as ex: + parser.error(f"argument --config-file: '{ex.config_file}' does not point to a configuration file") # fmt: skip + except MissingOutputDirectoryError: + parser.error("the following arguments are required: output_dir") + except UnknownClusterError as ex: + s = ", ".join(ex.supported_clusters) + + parser.error(f"argument --cluster: '{ex.cluster}' is not a known cluster. Must be one of: auto, none, {s}") # fmt: skip + except SweepFormatPlaceholderError as ex: + s = ", ".join(ex.unknown_keys) + + parser.error(f"argument --sweep-format: must contain only placeholders that correspond to the configuration keys, but contains the following unexpected placeholder(s): {s}") # fmt: skip + except SweepFormatError: + parser.error("argument --sweep-format: must be a non-empty string with brace-enclosed placeholders.") # fmt: skip + except ClusterError as ex: + if ex.cluster == "slurm": + log.exception("'{}' cluster environment cannot be set. See the logged stack trace for details. If you are within an allocated Slurm job (i.e. `salloc`), make sure to run with `srun`. If you want to run without Slurm, use `--cluster none`.", ex.cluster) # fmt: skip + else: + log.exception("'{}' cluster environment cannot be set. See the logged stack trace for details.", ex.cluster) # fmt: skip + except SetupError: + log.exception("The recipe initialization has failed. See the logged stack trace for details.") # fmt: skip + except StructureError: + log.exception("The recipe configuration cannot be parsed. See the logged stack trace for details.") # fmt: skip - sys.exit(1) + return 1 - try: - unknown_fields = update_dataclass( - config, config_overrides, value_converter=self._value_converter - ) - except FieldError as ex: - log.exception("Value of the field '{}' in the configuration file '{}' is invalid.", ex.field_name, config_file) # fmt: skip + def _print_presets(self) -> None: + console = get_console() - sys.exit(1) + names = self._preset_configs.names() - if unknown_fields: - log.error("Following fields in the configuration file '{}' are unknown: {}", config_file, ", ".join(unknown_fields)) # fmt: skip + if names: + console.print("available presets:") - sys.exit(1) + for preset in names: + if preset == self._default_preset: + console.print(f" - {preset} (default)") + else: + console.print(f" - {preset}") + else: + console.print("no preset configuration found.") - # Update the configuration with `--config`. - if args.config_overrides: - try: - unknown_fields = update_dataclass( - config, args.config_overrides, value_converter=self._value_converter - ) - except FieldError as ex: - log.exception("Value of the field '{}' in `--config` is invalid.", ex.field_name) # fmt: skip + def _create_recipe_program(self, args: Namespace) -> RecipeProgram: + file_system = StandardFileSystem() - sys.exit(1) + config_reader = StandardConfigReader( + self._preset_configs, file_system, load_yaml + ) - if unknown_fields: - log.error("Following fields in `--config` are unknown: {}", ", ".join(unknown_fields)) # fmt: skip + context = get_runtime_context() - sys.exit(1) + cluster_handlers = context.get_registry(ClusterHandler) - if args.dump_config: - dump_dataclass(config, sys.stdout) + cluster_resolver = ClusterResolver(cluster_handlers, is_torchrun=is_torchrun()) - sys.exit() + if not args.no_sweep_dir: + sweep_keys = get_sweep_keys(self._extra_sweep_keys) - # If we are not dumping configuration, `--output-dir` is required. - if not args.output_dir: - self._parser.error("the following arguments are required: output_dir") - - self._parser = None - - # Set up cluster-specific environment variables. - if args.cluster == "auto": - env_setter = self._env_setters.get_for_inferred_cluster() + sweep_tagger: SweepTagger = StandardSweepTagger(sweep_keys) else: - try: - env_setter = self._env_setters.get(args.cluster) - except RuntimeError: - log.exception("Recipe is not running on a '{}' cluster.", args.cluster) # fmt: skip + sweep_tagger = NoopSweepTagger() - sys.exit(1) - - try: - env_setter.set_torch_distributed_env() - except RuntimeError: - log.exception("'{}' cluster environment cannot be set.", env_setter.cluster) # fmt: skip + logging_initializer = DistributedLoggingInitializer() - sys.exit(1) + env_bootstrapper = StandardEnvironmentBootstrapper( + cluster_resolver, sweep_tagger, file_system, logging_initializer, dump_yaml + ) - # Determine the output directory. - if args.no_sweep_dir: - output_dir = args.output_dir - else: - tag = self._sweep_tagger(args.preset, preset_config, config) + signal_handler = SystemSignalHandler() - output_dir = args.output_dir.joinpath(tag) + runner = StandardRecipeRunner(self._loader, signal_handler) - # Set up distributed logging. - log_file = output_dir.expanduser().joinpath("logs/rank_{rank}.log").resolve() + return RecipeProgram(config_reader, env_bootstrapper, runner, dump_yaml) - try: - setup_logging(log_file, debug=args.debug) - except RuntimeError: - log.exception("Recipe logging cannot be set up.") - sys.exit(1) +@final +class RecipeProgram: + _config_reader: ConfigReader + _env_bootstrapper: EnvironmentBootstrapper + _runner: RecipeRunner + _yaml_dumper: YamlDumper - log.info("Log files stored under {}.", log_file.parent) + def __init__( + self, + config_reader: ConfigReader, + env_bootstrapper: EnvironmentBootstrapper, + runner: RecipeRunner, + yaml_dumper: YamlDumper, + ) -> None: + self._config_reader = config_reader + self._env_bootstrapper = env_bootstrapper + self._runner = runner + self._yaml_dumper = yaml_dumper + + def run(self, args: Namespace) -> None: + config = self._config_reader.read( + args.preset, args.config_files, args.config_overrides + ) - log_config(config, log, output_dir.joinpath("config.yaml")) + if args.dump_config: + unstructured_config = unstructure(config) - # Load and run the recipe. - recipe = self._loader(config, output_dir) + try: + self._yaml_dumper(unstructured_config, sys.stdout) + except YamlError as ex: + raise SetupError( + "The recipe configuration cannot be dumped to stdout. See the nested exception for details." + ) from ex - # If the recipe is stoppable, use SIGUSR1 as the stop signal. - if isinstance(recipe, Stoppable): + return - def request_stop(signum: int, frame: FrameType) -> None: - log.info("SIGUSR1 received. Requesting recipe to stop.") + if not args.output_dir: + raise MissingOutputDirectoryError("`args.output_dir` must be specified.") + + output_dir = self._env_bootstrapper.run( + args.preset, + config, + args.output_dir, + cluster=args.cluster, + sweep_format=args.sweep_format, + debug=args.debug, + ) - recipe.request_stop() + self._runner.run(config, output_dir) - signal(SIGUSR1, request_stop) - recipe() +class MissingOutputDirectoryError(ValueError): + pass diff --git a/src/fairseq2/recipes/cluster.py b/src/fairseq2/recipes/cluster.py new file mode 100644 index 000000000..7135364fa --- /dev/null +++ b/src/fairseq2/recipes/cluster.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +import subprocess +from abc import ABC, abstractmethod +from collections.abc import Collection +from random import Random +from typing import final + +from typing_extensions import override + +from fairseq2.context import Provider +from fairseq2.gang import get_rank, get_world_size + + +@final +class ClusterResolver: + _handlers: Provider[ClusterHandler] + _is_torchrun: bool + + def __init__(self, handlers: Provider[ClusterHandler], is_torchrun: bool) -> None: + self._handlers = handlers + self._is_torchrun = is_torchrun + + def get(self, name: str) -> ClusterHandler: + if self._is_torchrun or name == "none": + return _NoneClusterHandler() + + if name == "auto": + for _, handler in self._handlers.get_all(): + if handler.supports_current_cluster(): + return handler + + return _NoneClusterHandler() + + try: + return self._handlers.get(name) + except LookupError: + raise UnknownClusterError(name, self.supported_clusters()) from None + + def supported_clusters(self) -> Collection[str]: + return [str(key) for key, _ in self._handlers.get_all()] + + +class UnknownClusterError(LookupError): + cluster: str + supported_clusters: Collection[str] + + def __init__(self, cluster: str, supported_clusters: Collection[str]) -> None: + super().__init__(f"'{cluster}' is not a known cluster.") + + self.cluster = cluster + self.supported_clusters = supported_clusters + + +class ClusterHandler(ABC): + @abstractmethod + def set_torch_distributed_variables(self) -> tuple[int, int]: + """Set environment variables required to initialize ``torch.distributed``.""" + + @abstractmethod + def supports_current_cluster(self) -> bool: + """Return ``True`` if this instance supports the current cluster.""" + + +class ClusterError(Exception): + cluster: str + + def __init__(self, cluster: str, message: str) -> None: + super().__init__(message) + + self.cluster = cluster + + +@final +class SlurmClusterHandler(ClusterHandler): + _job_id: int | None + + def __init__(self) -> None: + self._job_id = None + + @override + def set_torch_distributed_variables(self) -> tuple[int, int]: + job_id = self._ensure_job_id() + + try: + os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"] + os.environ["RANK"] = os.environ["SLURM_PROCID"] + + try: + os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] + except KeyError: + os.environ["LOCAL_WORLD_SIZE"] = "1" + + os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] + + os.environ["MASTER_ADDR"] = self._get_master_addr() + os.environ["MASTER_PORT"] = self._get_master_port(job_id) + + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] + except KeyError as ex: + raise ClusterError( + "slurm", "Slurm job environment variables are not set correctly." + ) from ex + + return get_world_size(), get_rank() + + def _ensure_job_id(self) -> int: + if self._job_id is not None: + return self._job_id + + try: + job_id = os.environ["SLURM_JOB_ID"] + except KeyError: + raise ClusterError( + "slurm", "`SLURM_JOB_ID` environment variable does not exist." + ) from None + + try: + self._job_id = int(job_id) + except ValueError as ex: + raise ClusterError("slurm", "Slurm job ID cannot be parsed.") from ex + + return self._job_id + + @staticmethod + def _get_master_addr() -> str: + nodes = os.environ["SLURM_JOB_NODELIST"] + + result = subprocess.run( + ["scontrol", "show", "hostnames", nodes], capture_output=True, text=True + ) + + if result.returncode == 0: + if node_list := result.stdout.split("\n"): + return node_list[0] + + raise ClusterError( + "slurm", "The hostname or IP address of the Slurm node corresponding to rank 0 cannot be retrieved." # fmt: skip + ) + + @staticmethod + def _get_master_port(job_id: int) -> str: + try: + return os.environ["MASTER_PORT"] + except KeyError: + pass + + return str(Random(job_id).randint(20_000, 60_000)) + + @override + def supports_current_cluster(self) -> bool: + return "SLURM_JOB_ID" in os.environ + + +@final +class _NoneClusterHandler(ClusterHandler): + @override + def set_torch_distributed_variables(self) -> tuple[int, int]: + return get_world_size(), get_rank() + + @override + def supports_current_cluster(self) -> bool: + return True diff --git a/src/fairseq2/recipes/common_metrics.py b/src/fairseq2/recipes/common_metrics.py index c95f20f2a..3ff3ca265 100644 --- a/src/fairseq2/recipes/common_metrics.py +++ b/src/fairseq2/recipes/common_metrics.py @@ -6,8 +6,6 @@ from __future__ import annotations -from typing import Any, Dict, Optional - import torch from torch import Tensor from torcheval.metrics import Throughput @@ -15,25 +13,24 @@ from fairseq2.gang import Gang from fairseq2.generation import Seq2SeqGeneratorOutput, SequenceGeneratorOutput from fairseq2.metrics import MetricBag -from fairseq2.metrics.aggregation import Max, MaxSum, Mean, Sum +from fairseq2.metrics.aggregation import Max, Mean, Sum from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.models.sequence import SequenceBatch -from fairseq2.typing import override -class TaskMetricBag(MetricBag): - """Holds the metrics of a machine learning task.""" +class BaseMetricBag(MetricBag): + """Holds the base metrics of a machine learning task.""" _train: bool - _num_batches: MaxSum - _num_examples: Sum - _num_elements: Sum - _total_num_examples: Optional[Sum] - _total_num_elements: Optional[Sum] + + num_examples: Sum + num_elements: Sum + total_num_examples: Sum | None + total_num_elements: Sum | None def __init__(self, gang: Gang, train: bool) -> None: """ - :para train: + :param train: If ``True``, indicates that this bag is used in a training task. """ super().__init__(gang) @@ -42,148 +39,113 @@ def __init__(self, gang: Gang, train: bool) -> None: self._train = train - self.register_metric("_num_batches", MaxSum(device=d), persistent=False) - - self.register_metric("_num_examples", Sum(device=d), persistent=False) - self.register_metric("_num_elements", Sum(device=d), persistent=False) + self.register_metric("num_examples", Sum(device=d), persistent=False) + self.register_metric("num_elements", Sum(device=d), persistent=False) if train: - self._total_num_examples = Sum(device=d) - self._total_num_elements = Sum(device=d) - else: - self._total_num_examples = None - self._total_num_elements = None - - @override - def process_metric_values(self, values: Dict[str, Any]) -> None: - super().process_metric_values(values) - - num_batches = values.pop("num_batches") - - num_examples = values["num_examples"] - num_elements = values["num_elements"] - - if num_batches > 0: - values["batch_size"] = num_examples // num_batches + self.total_num_examples = Sum(device=d) + self.total_num_elements = Sum(device=d) else: - values["batch_size"] = 0 + self.total_num_examples = None + self.total_num_elements = None - if num_batches > 0: - values["elements_per_batch"] = num_elements // num_batches - else: - values["elements_per_batch"] = 0 + @property + def train(self) -> bool: + return self._train -class SequenceMetricBag(TaskMetricBag): +class SequenceMetricBag(BaseMetricBag): """Holds the metrics of a sequence model training or evaluation task.""" - _nll_loss: Mean - _num_target_elements: Sum - _total_num_target_elements: Optional[Sum] + nll_loss: Mean + num_target_elements: Sum + total_num_target_elements: Sum | None def __init__(self, gang: Gang, train: bool = True) -> None: super().__init__(gang, train=train) d = gang.device - self.register_metric("_nll_loss", Mean(device=d), persistent=False) + self.register_metric("nll_loss", Mean(device=d), persistent=False) - self.register_metric("_num_target_elements", Sum(device=d), persistent=False) + self.register_metric("num_target_elements", Sum(device=d), persistent=False) if train: - self._total_num_target_elements = Sum(device=d) + self.total_num_target_elements = Sum(device=d) else: - self._total_num_target_elements = None + self.total_num_target_elements = None @torch.inference_mode() def update_nll_loss(self, batch: SequenceBatch, loss: Tensor) -> None: - """Update the NLL loss metric. + """Update the NLL loss metric.""" + n = batch.num_target_elements() - :param batch: - The batch processed by the model. - :param nll_loss: - The loss of ``batch``. - """ - num_target_elements = batch.num_target_elements() - - self._nll_loss.update(loss / num_target_elements, weight=num_target_elements) + self.nll_loss.update(loss.detach() / n, weight=n) @torch.inference_mode() def update_batch_metrics(self, batch: SequenceBatch) -> None: - """Update the batch metrics. - - :param batch: - The batch processed by the model. - """ + """Update the batch metrics.""" num_examples = batch.batch_size - num_elements = batch.num_elements() num_target_elements = batch.num_target_elements() - self._num_batches.update(1) + num_elements = batch.num_elements() - self._num_examples.update(num_examples) - self._num_elements.update(num_elements) + self.num_examples.update(num_examples) + self.num_elements.update(num_elements) - self._num_target_elements.update(num_target_elements) + self.num_target_elements.update(num_target_elements) if self._train: - assert self._total_num_examples is not None - assert self._total_num_elements is not None - assert self._total_num_target_elements is not None + assert self.total_num_examples is not None + assert self.total_num_elements is not None - self._total_num_examples.update(num_examples) - self._total_num_elements.update(num_elements) + assert self.total_num_target_elements is not None - self._total_num_target_elements.update(num_target_elements) + self.total_num_examples.update(num_examples) + self.total_num_elements.update(num_elements) + self.total_num_target_elements.update(num_target_elements) -class Seq2SeqMetricBag(TaskMetricBag): + +class Seq2SeqMetricBag(BaseMetricBag): """Holds the metrics of a sequence-to-sequence model training or evaluation task.""" - _nll_loss: Mean - _num_source_elements: Sum - _num_target_elements: Sum - _total_num_source_elements: Optional[Sum] - _total_num_target_elements: Optional[Sum] + nll_loss: Mean + num_source_elements: Sum + num_target_elements: Sum + total_num_source_elements: Sum | None + total_num_target_elements: Sum | None def __init__(self, gang: Gang, train: bool = True) -> None: super().__init__(gang, train=train) d = gang.device - self.register_metric("_nll_loss", Mean(device=d), persistent=False) + self.register_metric("nll_loss", Mean(device=d), persistent=False) - self.register_metric("_num_source_elements", Sum(device=d), persistent=False) - self.register_metric("_num_target_elements", Sum(device=d), persistent=False) + self.register_metric("num_source_elements", Sum(device=d), persistent=False) + self.register_metric("num_target_elements", Sum(device=d), persistent=False) if train: - self._total_num_source_elements = Sum(device=d) - self._total_num_target_elements = Sum(device=d) + self.total_num_source_elements = Sum(device=d) + self.total_num_target_elements = Sum(device=d) else: - self._total_num_source_elements = None - self._total_num_target_elements = None + self.total_num_source_elements = None + self.total_num_target_elements = None @torch.inference_mode() def update_nll_loss(self, batch: Seq2SeqBatch, loss: Tensor) -> None: - """Update the NLL loss metric. - - :param batch: - The batch processed by the model. - :param nll_loss: - The loss of ``batch``. - """ + """Update the NLL loss metric.""" num_target_elements = batch.num_target_elements() - self._nll_loss.update(loss / num_target_elements, weight=num_target_elements) + self.nll_loss.update( + loss.detach() / num_target_elements, weight=num_target_elements + ) @torch.inference_mode() def update_batch_metrics(self, batch: Seq2SeqBatch) -> None: - """Update the batch metrics. - - :param batch: - The batch processed by the model. - """ + """Update the batch metrics.""" num_examples = batch.batch_size num_source_elements = batch.num_source_elements() @@ -191,59 +153,53 @@ def update_batch_metrics(self, batch: Seq2SeqBatch) -> None: num_elements = num_source_elements + num_target_elements - self._num_batches.update(1) - - self._num_examples.update(num_examples) - self._num_elements.update(num_elements) + self.num_examples.update(num_examples) + self.num_elements.update(num_elements) - self._num_source_elements.update(num_source_elements) - self._num_target_elements.update(num_target_elements) + self.num_source_elements.update(num_source_elements) + self.num_target_elements.update(num_target_elements) if self._train: - assert self._total_num_examples is not None - assert self._total_num_elements is not None + assert self.total_num_examples is not None + assert self.total_num_elements is not None - assert self._total_num_source_elements is not None - assert self._total_num_target_elements is not None + assert self.total_num_source_elements is not None + assert self.total_num_target_elements is not None - self._total_num_examples.update(num_examples) - self._total_num_elements.update(num_elements) + self.total_num_examples.update(num_examples) + self.total_num_elements.update(num_elements) - self._total_num_source_elements.update(num_source_elements) - self._total_num_target_elements.update(num_target_elements) + self.total_num_source_elements.update(num_source_elements) + self.total_num_target_elements.update(num_target_elements) -class SequenceGenerationMetricBag(TaskMetricBag): +class SequenceGenerationMetricBag(BaseMetricBag): """Holds the metrics of a sequence generation task.""" - _generator_prefill_size: Sum - _generator_num_elements: Sum - _generator_elements_per_second: Throughput - _generator_cache_size: Max - _generator_cache_capacity: Max + generator_prefill_size: Sum + generator_num_elements: Sum + generator_elements_per_second: Throughput + generator_cache_size: Max + generator_cache_capacity: Max def __init__(self, gang: Gang) -> None: super().__init__(gang, train=False) d = gang.device - self._generator_prefill_size = Sum(device=d) + self.generator_prefill_size = Sum(device=d) - self._generator_num_elements = Sum(device=d) + self.generator_num_elements = Sum(device=d) - self._generator_elements_per_second = Throughput(device=d) + self.generator_elements_per_second = Throughput(device=d) - self._generator_cache_size = Max(device=d) + self.generator_cache_size = Max(device=d) - self._generator_cache_capacity = Max(device=d) + self.generator_cache_capacity = Max(device=d) @torch.inference_mode() def update_batch_metrics(self, output: SequenceGeneratorOutput) -> None: - """Update the batch metrics. - - :param output: - The :class:`SequenceGenerator` output. - """ + """Update the batch metrics.""" num_examples = len(output.hypotheses) prefill_size = output.counters.prefill_size @@ -252,62 +208,54 @@ def update_batch_metrics(self, output: SequenceGeneratorOutput) -> None: num_elements = prefill_size + num_generated_elements - self._num_batches.update(1) - - self._num_examples.update(num_examples) - self._num_elements.update(num_elements) + self.num_examples.update(num_examples) + self.num_elements.update(num_elements) - self._generator_prefill_size.update(prefill_size) + self.generator_prefill_size.update(prefill_size) - self._generator_num_elements.update(num_generated_elements) + self.generator_num_elements.update(num_generated_elements) - self._generator_elements_per_second.update( + self.generator_elements_per_second.update( num_generated_elements, output.counters.generation_time ) - self._generator_cache_size.update(output.counters.cache_size) + self.generator_cache_size.update(output.counters.cache_size) - self._generator_cache_capacity.update(output.counters.cache_capacity) + self.generator_cache_capacity.update(output.counters.cache_capacity) -class Seq2SeqGenerationMetricBag(TaskMetricBag): +class Seq2SeqGenerationMetricBag(BaseMetricBag): """Holds the metrics of a sequence-to-sequence generation task.""" - _num_source_elements: Sum - _generator_prefill_size: Sum - _generator_num_elements: Sum - _generator_elements_per_second: Throughput - _generator_cache_size: Max - _generator_cache_capacity: Max + num_source_elements: Sum + generator_prefill_size: Sum + generator_num_elements: Sum + generator_elements_per_second: Throughput + generator_cache_size: Max + generator_cache_capacity: Max def __init__(self, gang: Gang) -> None: super().__init__(gang, train=False) d = gang.device - self._num_source_elements = Sum(device=d) + self.num_source_elements = Sum(device=d) - self._generator_prefill_size = Sum(device=d) + self.generator_prefill_size = Sum(device=d) - self._generator_num_elements = Sum(device=d) + self.generator_num_elements = Sum(device=d) - self._generator_elements_per_second = Throughput(device=d) + self.generator_elements_per_second = Throughput(device=d) - self._generator_cache_size = Max(device=d) + self.generator_cache_size = Max(device=d) - self._generator_cache_capacity = Max(device=d) + self.generator_cache_capacity = Max(device=d) @torch.inference_mode() def update_batch_metrics( self, output: Seq2SeqGeneratorOutput, num_source_elements: int ) -> None: - """Update the batch metrics. - - :param output: - The :class:`Seq2SeqGenerator` output. - :param num_source_elements: - The number of source elements processed by the model. - """ + """Update the batch metrics.""" num_examples = len(output.hypotheses) prefill_size = output.counters.prefill_size @@ -316,37 +264,53 @@ def update_batch_metrics( num_elements = num_source_elements + prefill_size + num_generated_elements - self._num_batches.update(1) + self.num_examples.update(num_examples) + self.num_elements.update(num_elements) - self._num_examples.update(num_examples) - self._num_elements.update(num_elements) + self.num_source_elements.update(num_source_elements) - self._num_source_elements.update(num_source_elements) + self.generator_prefill_size.update(prefill_size) - self._generator_prefill_size.update(prefill_size) + self.generator_num_elements.update(num_generated_elements) - self._generator_num_elements.update(num_generated_elements) - - self._generator_elements_per_second.update( + self.generator_elements_per_second.update( num_generated_elements, output.counters.generation_time ) - self._generator_cache_size.update(output.counters.cache_size) + self.generator_cache_size.update(output.counters.cache_size) + + self.generator_cache_capacity.update(output.counters.cache_capacity) + - self._generator_cache_capacity.update(output.counters.cache_capacity) +def extend_batch_metrics( + metric_values: dict[str, object], num_batches: int, elapsed_time: float +) -> None: + def get_value(name: str) -> int | float | Tensor | None: + try: + value = metric_values[name] + except KeyError: + return None + if not isinstance(value, (int, float, Tensor)): + return None -def set_throughput_value(metric_values: Dict[str, Any], elapsed_time: float) -> None: - """Set the throughput value in ``metric_values``.""" - try: - num_elements = metric_values["num_elements"] - except KeyError: - return + return value - if not isinstance(num_elements, (int, float, Tensor)): - return + num_examples = get_value("num_examples") + if num_examples is not None: + if num_batches > 0: + metric_values["batch_size"] = num_examples // num_batches + else: + metric_values["batch_size"] = 0 - if elapsed_time == 0.0: - metric_values["elements_per_second"] = 0.0 - else: - metric_values["elements_per_second"] = num_elements / elapsed_time + num_elements = get_value("num_elements") + if num_elements is not None: + if num_batches > 0: + metric_values["elements_per_batch"] = num_elements // num_batches + else: + metric_values["elements_per_batch"] = 0 + + if elapsed_time > 0.0: + metric_values["elements_per_second"] = num_elements / elapsed_time + else: + metric_values["elements_per_second"] = 0.0 diff --git a/src/fairseq2/recipes/early_stopper.py b/src/fairseq2/recipes/early_stopper.py new file mode 100644 index 000000000..64612813a --- /dev/null +++ b/src/fairseq2/recipes/early_stopper.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import final + +from typing_extensions import override + + +class EarlyStopper(ABC): + """Stops training when an implementation-specific condition is not met.""" + + @abstractmethod + def should_stop(self, step_nr: int, score: float) -> bool: + """ + :param step_nr: The number of the current training step. + :para score: The validation score of the current training step. + + :returns: ``True`` if the training should be stopped; otherwise, ``False``. + """ + + +@final +class NoopEarlyStopper(EarlyStopper): + @override + def should_stop(self, step_nr: int, score: float) -> bool: + return False diff --git a/src/fairseq2/recipes/evaluator.py b/src/fairseq2/recipes/evaluator.py index 5bc060f2a..3c9ed11e6 100644 --- a/src/fairseq2/recipes/evaluator.py +++ b/src/fairseq2/recipes/evaluator.py @@ -7,15 +7,18 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence +from contextlib import AbstractContextManager, nullcontext from itertools import count from pathlib import Path -from typing import Generic, List, Optional, Sequence, TypeVar, final +from typing import Generic, TypeVar, final import torch from torch.nn import Module +from typing_extensions import override from fairseq2.datasets import DataReader -from fairseq2.gang import FakeGang, Gang, all_sum +from fairseq2.gang import FakeGang, Gang from fairseq2.logging import get_log_writer from fairseq2.metrics import ( JsonFileMetricRecorder, @@ -25,9 +28,9 @@ TensorBoardRecorder, record_metrics, ) -from fairseq2.recipes.common_metrics import set_throughput_value -from fairseq2.recipes.utils.cli import create_rich_progress -from fairseq2.typing import CPU, override +from fairseq2.recipes.common_metrics import extend_batch_metrics +from fairseq2.recipes.utils.rich import create_rich_progress +from fairseq2.typing import CPU, DataType from fairseq2.utils.profiler import Stopwatch from fairseq2.utils.rng import RngBag @@ -57,7 +60,7 @@ def model(self) -> Module: @property @abstractmethod - def display_name(self) -> Optional[str]: + def display_name(self) -> str | None: """The display name of the unit for reporting purposes.""" @property @@ -70,9 +73,9 @@ class AbstractEvalUnit(EvalUnit[BatchT]): """Provides a skeletal implementation of :class:`EvalUnit`.""" _model: Module - _display_name: Optional[str] + _display_name: str | None - def __init__(self, model: Module, *, display_name: Optional[str] = None) -> None: + def __init__(self, model: Module, *, display_name: str | None = None) -> None: self._model = model self._display_name = display_name @@ -89,7 +92,7 @@ def model(self) -> Module: @final @property @override - def display_name(self) -> Optional[str]: + def display_name(self) -> str | None: return self._display_name @@ -102,7 +105,9 @@ class Evaluator(Generic[BatchT]): _root_gang: Gang _dp_gang: Gang _tp_gang: Gang - _metric_recorders: List[MetricRecorder] + _dtype: DataType + _amp: bool + _metric_recorders: list[MetricRecorder] _seed: int _wall_watch: Stopwatch _run: bool @@ -114,10 +119,12 @@ def __init__( data_readers: Sequence[DataReader[BatchT]], root_gang: Gang, wall_watch: Stopwatch, - dp_gang: Optional[Gang] = None, - tp_gang: Optional[Gang] = None, - tb_dir: Optional[Path] = None, - metrics_dir: Optional[Path] = None, + dp_gang: Gang | None = None, + tp_gang: Gang | None = None, + dtype: DataType = torch.float32, + amp: bool = False, + tb_dir: Path | None = None, + metrics_dir: Path | None = None, seed: int = 2, ) -> None: """ @@ -133,6 +140,10 @@ def __init__( The data parallel gang. If ``None``, ``root_gang`` will be used. :param tp_gang: The tensor parallel gang. Only required for tensor parallel models. + :param dtype: + The data type of the model. + :param amp: + If ``True``, enables ``torch.amp``. :param tb_dir: The TensorBoard log directory to dump metrics. :param metrics_dir: @@ -166,6 +177,10 @@ def __init__( f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead." ) + self._dtype = dtype + + self._amp = amp + if root_gang.rank == 0: self._metric_recorders = [LogMetricRecorder(log)] @@ -221,6 +236,8 @@ def _evaluate_unit( unit.model.eval() + num_effective_batches = 0 + with create_rich_progress() as progress: task = progress.add_task("eval", total=None) @@ -232,22 +249,25 @@ def _evaluate_unit( try: batches = next(data_reader) except StopIteration: - batches = [] + break for batch in batches: - unit(batch) + with self._maybe_autocast(): + unit(batch) - if self._is_eod(batches): - break + num_effective_batches += 1 - self._publish_metrics(unit, watch.get_elapsed_time()) + self._publish_metrics(unit, num_effective_batches, watch.get_elapsed_time()) - def _is_eod(self, batches: List[BatchT]) -> bool: - total_num_batches = all_sum(self._dp_gang, len(batches)) + def _maybe_autocast(self) -> AbstractContextManager[None]: + if self._dtype == torch.float32 or not self._amp: + return nullcontext() - return bool(total_num_batches == 0) + return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype) - def _publish_metrics(self, unit: EvalUnit[BatchT], elapsed_time: float) -> None: + def _publish_metrics( + self, unit: EvalUnit[BatchT], num_batches: int, elapsed_time: float + ) -> None: log.debug("Syncing metrics.") if self._tp_gang.rank == 0: @@ -260,9 +280,12 @@ def _publish_metrics(self, unit: EvalUnit[BatchT], elapsed_time: float) -> None: if self._root_gang.rank != 0: return - assert values is not None + if values is None: + raise RuntimeError( + "The synchronized metric values are `None`. Please file a bug report." + ) - set_throughput_value(values, elapsed_time) + extend_batch_metrics(values, num_batches, elapsed_time) values["elapsed_time"] = elapsed_time diff --git a/src/fairseq2/recipes/generator.py b/src/fairseq2/recipes/generator.py index e827e2f8c..bb9355b16 100644 --- a/src/fairseq2/recipes/generator.py +++ b/src/fairseq2/recipes/generator.py @@ -7,15 +7,17 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import AbstractContextManager, nullcontext from itertools import count from pathlib import Path -from typing import Generic, List, Optional, TypeVar, final +from typing import Generic, TypeVar, final import torch from torch.nn import Module +from typing_extensions import override from fairseq2.datasets import DataReader -from fairseq2.gang import FakeGang, Gang, all_sum +from fairseq2.gang import FakeGang, Gang from fairseq2.logging import get_log_writer from fairseq2.metrics import ( JsonFileMetricRecorder, @@ -24,9 +26,9 @@ MetricRecorder, record_metrics, ) -from fairseq2.recipes.common_metrics import set_throughput_value -from fairseq2.recipes.utils.cli import create_rich_progress -from fairseq2.typing import CPU, override +from fairseq2.recipes.common_metrics import extend_batch_metrics +from fairseq2.recipes.utils.rich import create_rich_progress +from fairseq2.typing import CPU, DataType from fairseq2.utils.profiler import Stopwatch from fairseq2.utils.rng import RngBag @@ -78,7 +80,9 @@ class Generator(Generic[BatchT]): _root_gang: Gang _dp_gang: Gang _tp_gang: Gang - _metric_recorders: List[MetricRecorder] + _dtype: DataType + _amp: bool + _metric_recorders: list[MetricRecorder] _seed: int _wall_watch: Stopwatch _run: bool @@ -90,9 +94,11 @@ def __init__( data_reader: DataReader[BatchT], root_gang: Gang, wall_watch: Stopwatch, - dp_gang: Optional[Gang] = None, - tp_gang: Optional[Gang] = None, - metrics_dir: Optional[Path] = None, + dp_gang: Gang | None = None, + tp_gang: Gang | None = None, + dtype: DataType = torch.float32, + amp: bool = False, + metrics_dir: Path | None = None, seed: int = 2, ) -> None: """ @@ -108,6 +114,10 @@ def __init__( The data parallel gang. If ``None``, ``gang`` will be used. :param tp_gang: The tensor parallel gang. Only required for tensor parallel models. + :param dtype: + The data type of the model. + :param amp: + If ``True``, enables ``torch.amp``. :param metrics_dir: The directory to dump metrics. :param seed: @@ -134,6 +144,10 @@ def __init__( f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead." ) + self._dtype = dtype + + self._amp = amp + if root_gang.rank == 0: self._metric_recorders = [LogMetricRecorder(log)] @@ -177,6 +191,8 @@ def _do_run(self) -> None: self._unit.model.eval() + num_effective_batches = 0 + with create_rich_progress() as progress: task = progress.add_task("generate", total=None) @@ -188,22 +204,23 @@ def _do_run(self) -> None: try: batches = next(self._data_reader) except StopIteration: - batches = [] + break for batch in batches: - self._unit(batch) + with self._maybe_autocast(): + self._unit(batch) - if self._is_eod(batches): - break + num_effective_batches += 1 - self._publish_metrics(watch.get_elapsed_time()) + self._publish_metrics(num_effective_batches, watch.get_elapsed_time()) - def _is_eod(self, batches: List[BatchT]) -> bool: - total_num_batches = all_sum(self._dp_gang, len(batches)) + def _maybe_autocast(self) -> AbstractContextManager[None]: + if self._dtype == torch.float32 or not self._amp: + return nullcontext() - return bool(total_num_batches == 0) + return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype) - def _publish_metrics(self, elapsed_time: float) -> None: + def _publish_metrics(self, num_batches: int, elapsed_time: float) -> None: log.debug("Syncing metrics.") if self._tp_gang.rank != 0: @@ -214,9 +231,12 @@ def _publish_metrics(self, elapsed_time: float) -> None: if self._root_gang.rank != 0: return - assert values is not None + if values is None: + raise RuntimeError( + "The synchronized metric values are `None`. Please file a bug report." + ) - set_throughput_value(values, elapsed_time) + extend_batch_metrics(values, num_batches, elapsed_time) values["elapsed_time"] = elapsed_time diff --git a/src/fairseq2/recipes/hg/__init__.py b/src/fairseq2/recipes/hg/__init__.py new file mode 100644 index 000000000..36d4ba38a --- /dev/null +++ b/src/fairseq2/recipes/hg/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +try: + import datasets # type: ignore[attr-defined,import-untyped,import-not-found] + + _has_hg_datasets = True +except ImportError: + _has_hg_datasets = False + + +try: + import evaluate # type: ignore[attr-defined,import-untyped,import-not-found] + + _has_hg_evaluate = True +except ImportError: + _has_hg_evaluate = False + + +from fairseq2.recipes.cli import Cli, RecipeCommandHandler + + +def _setup_hg_cli(cli: Cli) -> None: + if not _has_hg_datasets or not _has_hg_evaluate: + return + + group = cli.add_group("hg", help="Hugging Face recipes") + + from fairseq2.recipes.hg.asr_eval import ( + asr_eval_presets, + load_wav2vec2_asr_evaluator, + ) + + handler = RecipeCommandHandler( + load_wav2vec2_asr_evaluator, + preset_configs=asr_eval_presets, + default_preset="librispeech_asr", + ) + + group.add_command( + "wav2vec2_asr", + handler, + help="evaluate a wav2vec 2.0 ASR model on a downstream benchmark (default: librispeech_asr)", + ) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py new file mode 100644 index 000000000..fbcab14a8 --- /dev/null +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path +from typing import Any, cast + +import torch +from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] + Dataset, + load_dataset, +) + +from fairseq2.config_registry import ConfigRegistry +from fairseq2.data.data_pipeline import SequenceData +from fairseq2.data.text import TextTokenizer, load_text_tokenizer +from fairseq2.datasets import StaticBatching +from fairseq2.logging import get_log_writer +from fairseq2.models.seq2seq import Seq2SeqBatch +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model +from fairseq2.nn.padding import get_seqs_and_padding_mask +from fairseq2.recipes.hg.dataset import Example, create_hf_reader +from fairseq2.recipes.hg.evaluator import HFEvaluator +from fairseq2.recipes.utils.setup import setup_root_gang +from fairseq2.typing import META, DataType +from fairseq2.utils.profiler import Stopwatch + +log = get_log_writer(__name__) + + +@dataclass(kw_only=True) +class AsrEvalConfig: + """Holds the configuration of a ASR evaluation recipe.""" + + # Data + dataset_name: str + """The HF dataset to evaluate with.""" + + # Model + model_name: str + """The name of the model to evaluate.""" + + # converter: Callable[[Example], Seq2SeqBatch] + # """The converter function to convert collated data into Seq2SeqBatch""" + + tokenizer_name: str = "librispeech_asr" + """The tokenizer to use.""" + + split: str = "test" + """The name of the dataset split to evaluate with.""" + + min_audio_len: int = 1 + """The minimum audio sequence length.""" + + max_audio_len: int = 800_000 + """The maximum audio sequence length.""" + + max_num_elements: int = 3_200_000 + """The maximum number of elements per batch.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" + + max_samples: int | None = None + """Maximum number of samples from the dataset to be evaluated. Used + e.g. for debugging. Default is None, meaning all samples will be evaluated""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + checkpoint_dir: Path | None = None + """The checkpoint directory containing models saved by a :class:`FileCheckpointManager`.""" + + dtype: DataType = torch.float16 + """The data type of the model.""" + + +asr_eval_presets = ConfigRegistry[AsrEvalConfig]() + +asr_eval_preset = asr_eval_presets.decorator + + +@asr_eval_preset("librispeech_asr") +def _librispeech_asr_config() -> AsrEvalConfig: + return AsrEvalConfig( + dataset_name="librispeech_asr", + model_name="wav2vec2_asr_base_10h", + split="test.other", + # converter=librispeech_asr_to_batch, + ) + + +def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: + """ + Converts a collated batch of examples into a Seq2SeqBatch. + + Args: + examples (dict): A dictionary containing "audio" and "text" keys. + + Returns: + Seq2SeqBatch: A batch of audio and text sequences. + """ + source_data = cast(SequenceData, examples["audio"]) + target_data = cast(SequenceData, examples["text"]) + + source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) + + return Seq2SeqBatch( + source_seqs, + source_padding_mask, + target_seqs, + target_padding_mask, + examples, + ) + + +@lru_cache(maxsize=None) +def get_cached_tokenizer(tokenizer_name: str) -> TextTokenizer: + return load_text_tokenizer(tokenizer_name) + + +def _preprocess_example( + example: Example, tokenizer_name: str, device: torch.device +) -> Example: + """ + Preprocesses an individual example by converting the audio array to a PyTorch tensor + and encoding the text. + + Args: + example (dict): A dictionary containing "audio" and "text" keys. + tokenizer_name (str): The name of the tokenizer to use. + device (torch.device): The device to store the tensors. + + Returns: + dict: A dictionary with "audio" and "text" as PyTorch tensors. + """ + tokenizer = get_cached_tokenizer(tokenizer_name) + encoder = tokenizer.create_encoder(device=device) + audio_tensor = ( + torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device) + ) + text_tensor = encoder(example["text"].lower()).to(device) + return {"audio": audio_tensor, "text": text_tensor} + + +def seq2seq_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: + return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch( + batch.target_seqs, batch.target_padding_mask + ) + + +def postprocesser( + outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer +) -> tuple[list[str], list[str]]: + decoder = tokenizer.create_decoder() + pad_idx = tokenizer.vocab_info.pad_idx + + hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) + predictions = [decoder(item) for item in hypotheses] + references = [decoder(item) for item in targets.seqs.to(torch.int32)] + + return predictions, references + + +def load_wav2vec2_asr_evaluator( + config: AsrEvalConfig, output_dir: Path +) -> HFEvaluator[Seq2SeqBatch]: + """ + Load the evaluator used for downstream evaluation of the model + in a downstream dataset and report BLEU scores + + Args: + config (HFEvalConfig): The configuration for the evaluation. + output_dir (Path): The output directory to store the evaluation results. + + Returns: + HFEvaluator: Evaluation process results. + """ + if not isinstance(config, AsrEvalConfig): + raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") + + iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) + # Load a subset of the dataset if max_samples is set + ds = Dataset.from_generator( + lambda: itertools.islice(iterable_ds, 0, config.max_samples), + features=iterable_ds.features, + ) + + gang = setup_root_gang(log) + + if gang.rank == 0: + init_device = gang.device + else: + init_device = META + + ds = ds.map(lambda x: _preprocess_example(x, config.tokenizer_name, init_device)) + format = { + "type": "torch", + "format_kwargs": {"dtype": torch.float16, "device": init_device}, + } + ds.set_format(**format, columns=["audio", "text"]) + + tokenizer = get_cached_tokenizer(config.tokenizer_name) + + pipeline_reader = create_hf_reader( + dataset=ds, + gang=gang, + converter=_librispeech_asr_to_batch, + batching=StaticBatching(config.max_num_elements), + num_prefetch=config.num_prefetch, + pad_value=tokenizer.vocab_info.pad_idx, + max_seq_len=config.max_audio_len, + ) + + model = load_wav2vec2_asr_model( + config.model_name, device=init_device, dtype=config.dtype + ) + + wall_watch = Stopwatch(start=True, device=init_device) + + return HFEvaluator[Seq2SeqBatch]( + model=model, + metrics=["bleu"], + gang=gang, + data_reader=pipeline_reader, + wall_watch=wall_watch, + preprocessor=seq2seq_preprocessor, + postprocessor=lambda x, y: postprocesser(x, y, tokenizer), + ) diff --git a/src/fairseq2/datasets/huggingface.py b/src/fairseq2/recipes/hg/dataset.py similarity index 89% rename from src/fairseq2/datasets/huggingface.py rename to src/fairseq2/recipes/hg/dataset.py index cf5e3fc2f..5ac62ee9b 100644 --- a/src/fairseq2/datasets/huggingface.py +++ b/src/fairseq2/recipes/hg/dataset.py @@ -4,10 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Dict, Optional, Union +from __future__ import annotations + +from collections.abc import Callable +from typing import Any from fairseq2.data.data_pipeline import Collater, create_bucket_sizes, read_sequence -from fairseq2.datasets.batching import LengthBatching, StaticBatching +from fairseq2.datasets import Batching, LengthBatching, StaticBatching from fairseq2.datasets.data_reader import BatchT, DataPipelineReader from fairseq2.gang import Gang @@ -21,23 +24,22 @@ else: has_datasets = True - -Example = Dict[str, Any] +Example = dict[str, Any] def create_hf_reader( - dataset: Dataset, + dataset: "Dataset", gang: Gang, converter: Callable[[Example], BatchT], *, - batching: Optional[Union[StaticBatching, LengthBatching]] = None, - max_seq_len: Optional[int] = None, + batching: Batching | None = None, + max_seq_len: int | None = None, drop_remainder: bool = False, min_seq_len: int = 0, - seq_len_col: Optional[str] = None, + seq_len_col: str | None = None, num_accumulate: int = 1, num_prefetch: int = 1, - pad_value: Optional[int] = None, + pad_value: int | None = None, **extra: Any, ) -> DataPipelineReader[BatchT]: """ @@ -72,11 +74,7 @@ def create_hf_reader( training. :param num_prefetch: The number of batches to prefetch in background. - :param extras: - The extra parameters specific to the dataset - implementation. """ - if not has_datasets: raise ModuleNotFoundError( "`datasets` is required but not found. Please install it with `pip install datasets`." @@ -107,17 +105,8 @@ def create_hf_reader( raise ValueError( "`max_seq_len` is required if batching strategy is specified" ) - if isinstance(batching, StaticBatching): - if seq_len_col: - def skip(example: Example) -> bool: - _len = len(example[seq_len_col]) - return _len >= min_seq_len and _len <= max_seq_len - - builder.filter(skip) - - builder = builder.bucket(batching.batch_size, drop_remainder=drop_remainder) - else: + if isinstance(batching, LengthBatching): bucket_sizes = create_bucket_sizes( max_seq_len=max_seq_len, min_seq_len=min_seq_len, @@ -134,6 +123,18 @@ def skip(example: Example) -> bool: skip_above_max_examples=True, drop_remainder=drop_remainder, ) + elif isinstance(batching, StaticBatching): + if seq_len_col: + + def skip(example: Example) -> bool: + _len = len(example[seq_len_col]) + return _len >= min_seq_len and _len <= max_seq_len + + builder.filter(skip) + + builder = builder.bucket(batching.batch_size, drop_remainder=drop_remainder) + else: + raise RuntimeError(f"`{batching}` is not supported.") # collate to python dict builder.map(Collater(pad_value=pad_value)) @@ -146,6 +147,7 @@ def skip(example: Example) -> bool: pipeline = builder.map(converter).and_return() return DataPipelineReader[BatchT]( + "hg", pipeline, gang, num_accumulate=num_accumulate, diff --git a/src/fairseq2/recipes/hg/evaluator.py b/src/fairseq2/recipes/hg/evaluator.py new file mode 100644 index 000000000..5f207e20c --- /dev/null +++ b/src/fairseq2/recipes/hg/evaluator.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import importlib +from collections.abc import Callable +from itertools import count +from pathlib import Path +from typing import Any, Generic, TypeVar, final + +from fairseq2.datasets import DataReader +from fairseq2.gang import FakeGang, Gang +from fairseq2.logging import get_log_writer +from fairseq2.metrics import ( + LogMetricRecorder, + MetricRecorder, + TensorBoardRecorder, + record_metrics, +) +from fairseq2.models.model import Model +from fairseq2.models.sequence import SequenceBatch +from fairseq2.recipes.utils.rich import create_rich_progress +from fairseq2.utils.profiler import Stopwatch + +log = get_log_writer(__name__) + + +BatchT = TypeVar("BatchT") + + +@final +class HFEvaluator(Generic[BatchT]): + """Evaluate a machine learning model with HuggingFace's evaluate.Metric library""" + + _model: Model + _preprocessor: Callable[[BatchT], tuple[SequenceBatch, SequenceBatch]] + _postprocessor: Callable[[Any, SequenceBatch], tuple[list[str], list[str]]] + _root_gang: Gang + _dp_gang: Gang + _tp_gang: Gang + _data_reader: DataReader[BatchT] + _metric_recorders: list[MetricRecorder] + _wall_watch: Stopwatch + _elapsed_time: float + _run: bool + + def __init__( + self, + model: Model, + metrics: list[str], + gang: Gang, + data_reader: DataReader[BatchT], + wall_watch: Stopwatch, + preprocessor: Callable[[BatchT], tuple[SequenceBatch, SequenceBatch]], + postprocessor: Callable[[Any, SequenceBatch], tuple[list[str], list[str]]], + dp_gang: Gang | None = None, + tp_gang: Gang | None = None, + tb_dir: Path | None = None, + ) -> None: + """ + :param model: + The fairseq2 machine learning model to be evaluate + :param metrics: + The list of metric names implemented in HuggingFace.evaluate + :param gang: + The gang to use for distributed evaluation. + :param data_reader: + The data reader of the eval split. + :param wall_watch: + The stopwatch to track process wall-time. + :param preprocessor: + The preprocessor to convert the batch into inputs and targets SequenceBatch objects. + :param postprocessor: + The postprocessor to convert the model outputs and target sequences into predictions and references. + :param dp_gang: + The data parallel gang. If ``None``, ``gang`` will be used. + :param tp_gang: + The tensor parallel gang. Only required for tensor parallel models. + :param tb_dir: + The TensorBoard log directory to dump metrics. + """ + try: + evaluate = importlib.import_module("evaluate") + except ImportError as exc: + raise ImportError( + "HFMetric requires the library `evaluate`, for instance via `pip install evaluate`" + ) from exc + + self._model = model + + self._root_gang = gang + + if dp_gang is not None and tp_gang is not None: + self._dp_gang = dp_gang + self._tp_gang = tp_gang + elif dp_gang is None and tp_gang is None: + self._dp_gang = gang + self._tp_gang = FakeGang(device=gang.device) + else: + raise ValueError("`dp_gang` and `tp_gang` must be both specified.") + + self._data_reader = data_reader + + self._metrics = evaluate.combine(metrics) + + self._preprocessor = preprocessor + + self._postprocessor = postprocessor + + if self._tp_gang.rank == 0 and self._dp_gang.rank == 0: + self._metric_recorders = [LogMetricRecorder(log)] + + if tb_dir is not None: + self._metric_recorders.append(TensorBoardRecorder(tb_dir)) + else: + self._metric_recorders = [] + + self._wall_watch = wall_watch + + self._elapsed_time = 0.0 + + self._run = False + + def __call__(self) -> None: + if self._run: + raise RuntimeError("The evaluator can only be run once.") + + self._run = True + + log.info("Running evaluation on {} device(s).", self._root_gang.size) + + try: + self._do_run() + self._publish_evaluation_metrics() + except KeyboardInterrupt: + log.info("Evaluation terminated") + + raise + + elapsed_time = self._wall_watch.get_elapsed_time() + + log.info("Evaluation complete in {:,} seconds", int(elapsed_time)) + + def _do_run(self) -> None: + with create_rich_progress() as progress: + eval_task = progress.add_task("eval", total=None) + + watch = Stopwatch(start=True, device=self._root_gang.device) + + for step_nr in count(start=1): + self._step_nr = step_nr + + try: + batches = next(self._data_reader) + except StopIteration: + break + + progress.update(eval_task, refresh=True, advance=1) + + log.debug("Running step {}.", step_nr) + + for batch in batches: + inputs, targets = self._preprocessor(batch) + outputs = self._model(inputs) + predictions, references = self._postprocessor(outputs, targets) + + self._metrics.add_batch( + predictions=predictions, references=references + ) + + self._root_gang.barrier() + + self._elapsed_time = watch.get_elapsed_time() + + def _publish_evaluation_metrics(self) -> None: + """ + publish evaluation metrics to log and TensorBoard folder. + Note that contrast to fairseq2.metrics, which rely on torcheval, + + HuggingFace's evaluate has an internal support for distributed + evaluation (see + https://huggingface.co/docs/evaluate/en/a_quick_tour#distributed-evaluation), + so we do not to call explicitly sync_and_compute_metrics(), but simply + evaluate.compute() + """ + values = self._metrics.compute() + + # In all other rank, values will be zero + if self._tp_gang.rank != 0 or self._dp_gang.rank != 0: + return + + assert values is not None + + values["elapsed_time"] = self._elapsed_time + + values["wall_time"] = self._wall_watch.get_elapsed_time() + + record_metrics(self._metric_recorders, "eval", values) diff --git a/src/fairseq2/recipes/llama/__init__.py b/src/fairseq2/recipes/llama/__init__.py index 5a45d11b7..878318a94 100644 --- a/src/fairseq2/recipes/llama/__init__.py +++ b/src/fairseq2/recipes/llama/__init__.py @@ -8,6 +8,7 @@ from fairseq2.recipes.cli import Cli from fairseq2.recipes.llama.convert_checkpoint import ConvertCheckpointCommandHandler +from fairseq2.recipes.llama.write_hf_config import WriteHfConfigCommandHandler def _setup_llama_cli(cli: Cli) -> None: @@ -18,3 +19,9 @@ def _setup_llama_cli(cli: Cli) -> None: handler=ConvertCheckpointCommandHandler(), help="convert fairseq2 LLaMA checkpoints to reference checkpoints", ) + + group.add_command( + name="write_hf_config", + handler=WriteHfConfigCommandHandler(), + help="write fairseq2 LLaMA config in Huggingface config format", + ) diff --git a/src/fairseq2/recipes/llama/convert_checkpoint.py b/src/fairseq2/recipes/llama/convert_checkpoint.py index c4ab7003a..5df21d2e6 100644 --- a/src/fairseq2/recipes/llama/convert_checkpoint.py +++ b/src/fairseq2/recipes/llama/convert_checkpoint.py @@ -15,13 +15,18 @@ from typing import final from warnings import catch_warnings -from fairseq2.console import get_error_console +from typing_extensions import override + +from fairseq2.assets import default_asset_store from fairseq2.logging import get_log_writer from fairseq2.models.llama import load_llama_config -from fairseq2.models.llama.integ import convert_to_reference_checkpoint +from fairseq2.models.llama.integ import ( + convert_to_reference_checkpoint, + get_ffn_dim_multipliers, +) from fairseq2.recipes.cli import CliCommandHandler -from fairseq2.typing import override -from fairseq2.utils.file import dump_tensors, load_tensors +from fairseq2.recipes.utils.rich import get_error_console +from fairseq2.utils.file import dump_torch_tensors, load_torch_tensors log = get_log_writer(__name__) @@ -33,9 +38,9 @@ class ConvertCheckpointCommandHandler(CliCommandHandler): @override def init_parser(self, parser: ArgumentParser) -> None: parser.add_argument( - "--arch", + "--model", metavar="ARCH_NAME", - help="architecture name to generate params.json", + help="model name to fetch architecture to generate params.json", ) parser.add_argument( @@ -51,7 +56,7 @@ def init_parser(self, parser: ArgumentParser) -> None: ) @override - def __call__(self, args: Namespace) -> None: + def run(self, parser: ArgumentParser, args: Namespace) -> int: if not args.input_dir.exists() or not args.input_dir.is_dir(): log.error("`input_dir` must be a directory.") @@ -62,8 +67,12 @@ def __call__(self, args: Namespace) -> None: sys.exit(1) - if args.arch: - model_config = load_llama_config(args.arch) + arch = ( + default_asset_store.retrieve_card(args.model).field("model_arch").as_(str) + ) + + if arch: + model_config = load_llama_config(args.model) else: model_config = None @@ -105,7 +114,7 @@ def __call__(self, args: Namespace) -> None: with catch_warnings(): warnings.simplefilter("ignore") - checkpoint = load_tensors(input_file, restrict=True) + checkpoint = load_torch_tensors(input_file, restrict=True) except RuntimeError: log.exception( "Checkpoint file {} cannot be loaded.", input_file.name @@ -113,8 +122,8 @@ def __call__(self, args: Namespace) -> None: sys.exit(1) - if "model" not in checkpoint: - log.error("Checkpoint file {} does not contain a 'model' entry.", input_file.name) # fmt: skip + if all(key not in checkpoint for key in ["model_key", "model"]): + log.error("Checkpoint file {} does not contain a 'model_key' nor 'model' entry.", input_file.name) # fmt: skip sys.exit(1) @@ -125,11 +134,9 @@ def __call__(self, args: Namespace) -> None: ref_state_dict = convert_to_reference_checkpoint(checkpoint) try: - dump_tensors(ref_state_dict, output_file) + dump_torch_tensors(ref_state_dict, output_file) except RuntimeError: - log.exception( - "Checkpoint file {} cannot be saved.", output_file.name - ) + log.exception("Checkpoint file {} cannot be saved.", output_file.name) # fmt: skip sys.exit(1) @@ -151,8 +158,10 @@ def __call__(self, args: Namespace) -> None: if model_config.num_attn_heads != model_config.num_key_value_heads: params["model"]["n_kv_heads"] = model_config.num_key_value_heads - if args.arch == "llama2_70b" or args.arch.startswith("llama3"): - params["model"]["ffn_dim_multiplier"] = 1.3 + ffn_dim_multiplier = get_ffn_dim_multipliers(arch) + + if ffn_dim_multiplier != 1.0: + params["model"]["ffn_dim_multiplier"] = ffn_dim_multiplier try: with args.output_dir.joinpath("params.json").open("w") as fp: @@ -162,4 +171,6 @@ def __call__(self, args: Namespace) -> None: sys.exit(1) - log.info("params.json generated for {}.", args.arch) + log.info("params.json generated for {}.", args.model) + + return 0 diff --git a/src/fairseq2/recipes/llama/write_hf_config.py b/src/fairseq2/recipes/llama/write_hf_config.py new file mode 100644 index 000000000..2cb606c99 --- /dev/null +++ b/src/fairseq2/recipes/llama/write_hf_config.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import json +import sys +from argparse import ArgumentParser, Namespace +from pathlib import Path +from typing import final + +from typing_extensions import override + +from fairseq2.assets import default_asset_store +from fairseq2.logging import get_log_writer +from fairseq2.models.llama import load_llama_config +from fairseq2.models.llama.integ import convert_to_huggingface_config +from fairseq2.recipes.cli import CliCommandHandler + +log = get_log_writer(__name__) + + +@final +class WriteHfConfigCommandHandler(CliCommandHandler): + """Writes fairseq2 LLaMA config files in Huggingface format.""" + + @override + def init_parser(self, parser: ArgumentParser) -> None: + parser.add_argument( + "--model", + metavar="ARCH_NAME", + help="model name to fetch architecture to generate config.json", + ) + + parser.add_argument( + "output_dir", + type=Path, + help="output directory to store reference checkpoint", + ) + + @override + def run(self, parser: ArgumentParser, args: Namespace) -> int: + arch = ( + default_asset_store.retrieve_card(args.model).field("model_arch").as_(str) + ) + + if arch: + model_config = load_llama_config(args.model) + else: + model_config = None + + if model_config is None: + log.error("Config could not be retrieved for model {}", args.model) + + sys.exit(1) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Convert and write the config + log.info("Writing config...") + + config = convert_to_huggingface_config(arch, model_config) + + json_file = args.output_dir.joinpath("config.json") + + try: + with json_file.open("w") as fp: + json.dump(config, fp, indent=2, sort_keys=True) + except OSError as ex: + raise RuntimeError( + f"The file {json_file} cannot be saved. See the nested exception for details." + ) from ex + + log.info("Config converted and saved in {}", json_file) + + return 0 diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index 0af279e28..568c3d1c5 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -8,6 +8,7 @@ from fairseq2.recipes.cli import Cli, RecipeCommandHandler from fairseq2.recipes.lm.chatbot import ChatbotCommandHandler +from fairseq2.recipes.lm.eval_nll import load_nll_evaluator, nll_eval_presets from fairseq2.recipes.lm.instruction_finetune import ( instruction_finetune_presets, load_instruction_finetuner, @@ -33,7 +34,7 @@ def _setup_lm_cli(cli: Cli) -> None: instruction_finetune_handler = RecipeCommandHandler( loader=load_instruction_finetuner, preset_configs=instruction_finetune_presets, - default_preset="llama3_8b_instruct", + default_preset="llama3_1_instruct", ) group.add_command( @@ -42,11 +43,24 @@ def _setup_lm_cli(cli: Cli) -> None: help="instruction-finetune a language model", ) + # Preference Finetune + preference_finetune_handler = RecipeCommandHandler( + loader=load_preference_finetuner, + preset_configs=preference_finetune_presets, + default_preset="llama3_1_instruct", + ) + + group.add_command( + name="preference_finetune", + handler=preference_finetune_handler, + help="preference-finetune a language model (e.g. DPO, SimPO).", + ) + # Text Generate text_generate_handler = RecipeCommandHandler( loader=load_text_generator, preset_configs=text_generate_presets, - default_preset="llama3_8b_instruct", + default_preset="llama3_1_8b_instruct", ) group.add_command( @@ -55,15 +69,15 @@ def _setup_lm_cli(cli: Cli) -> None: help="generate text", ) - # Preference Finetune - preference_finetune_handler = RecipeCommandHandler( - loader=load_preference_finetuner, - preset_configs=preference_finetune_presets, - default_preset="llama3_8b_instruct", + # NLL evaluation + nll_eval_handler = RecipeCommandHandler( + loader=load_nll_evaluator, + preset_configs=nll_eval_presets, + default_preset="llama3_1_base_eval", ) group.add_command( - name="preference_finetune", - handler=preference_finetune_handler, - help="preference-finetune a language model", + name="nll_eval", + handler=nll_eval_handler, + help="Evaluate the model and compute NLL loss over a given dataset", ) diff --git a/src/fairseq2/recipes/lm/chatbot.py b/src/fairseq2/recipes/lm/chatbot.py index 26ff3bb76..9a49ecb16 100644 --- a/src/fairseq2/recipes/lm/chatbot.py +++ b/src/fairseq2/recipes/lm/chatbot.py @@ -9,28 +9,32 @@ import sys from argparse import ArgumentParser, Namespace from datetime import timedelta -from typing import List, Optional, final +from typing import final import torch - -from fairseq2.console import get_console -from fairseq2.data.text import load_text_tokenizer -from fairseq2.gang import Gang +from rich.console import Console +from torch import Tensor +from typing_extensions import override + +from fairseq2.chatbots import Chatbot, ChatMessage, create_chatbot +from fairseq2.context import get_runtime_context +from fairseq2.data.text import TextTokenDecoder, TextTokenizer, load_text_tokenizer +from fairseq2.error import InternalError +from fairseq2.gang import Gang, is_torchrun from fairseq2.generation import ( - Chatbot, - ChatMessage, SamplingSequenceGenerator, + SequenceGenerator, TopPSampler, ) from fairseq2.logging import get_log_writer -from fairseq2.models import create_chatbot, load_model +from fairseq2.models import load_model from fairseq2.models.decoder import DecoderModel from fairseq2.recipes.cli import CliCommandHandler -from fairseq2.recipes.logging import setup_basic_logging +from fairseq2.recipes.cluster import ClusterError, ClusterHandler, ClusterResolver from fairseq2.recipes.utils.argparse import parse_dtype -from fairseq2.recipes.utils.environment import default_env_setters -from fairseq2.recipes.utils.setup import check_model_type, setup_gangs -from fairseq2.typing import CPU, override +from fairseq2.recipes.utils.rich import get_console +from fairseq2.recipes.utils.setup import setup_gangs +from fairseq2.typing import CPU from fairseq2.utils.rng import RngBag log = get_log_writer(__name__) @@ -93,36 +97,32 @@ def init_parser(self, parser: ArgumentParser) -> None: help="maximum sequence generation length (default: %(default)s)", ) - clusters = list(default_env_setters.names()) - - clusters.sort() - parser.add_argument( "--cluster", - choices=["auto"] + clusters, default="auto", - help="cluster on which the chatbot runs (default: %(default)s)", + help="cluster on which the recipe runs (default: %(default)s)", ) @override - def __call__(self, args: Namespace) -> None: - setup_basic_logging() + def run(self, parser: ArgumentParser, args: Namespace) -> int: + context = get_runtime_context() + + cluster_handlers = context.get_registry(ClusterHandler) + + cluster_resolver = ClusterResolver(cluster_handlers, is_torchrun=is_torchrun()) # Set up cluster-specific environment variables. - if args.cluster == "auto": - env_setter = default_env_setters.get_for_inferred_cluster() - else: - try: - env_setter = default_env_setters.get(args.cluster) - except RuntimeError: - log.exception("Chatbot is not running on a '{}' cluster.", args.cluster) # fmt: skip + try: + cluster_handler = cluster_resolver.get(args.cluster) + except LookupError: + log.exception("Chatbot is not running on a '{}' cluster.", args.cluster) # fmt: skip - sys.exit(1) + sys.exit(1) try: - env_setter.set_torch_distributed_env() - except RuntimeError: - log.exception("'{}' cluster environment cannot be set.", env_setter.cluster) # fmt: skip + cluster_handler.set_torch_distributed_variables() + except ClusterError: + log.exception("'{}' cluster environment cannot be set.", args.cluster) # fmt: skip sys.exit(1) @@ -147,7 +147,10 @@ def __call__(self, args: Namespace) -> None: model = load_model(args.model_name, gangs=gangs, dtype=args.dtype) - check_model_type(model, DecoderModel) + if not isinstance(model, DecoderModel): + log.exception("The model must be of type `{}`, but is of type `{}` instead.", DecoderModel, type(model)) # fmt: skip + + sys.exit(1) log.info("Model loaded.") @@ -158,16 +161,41 @@ def __call__(self, args: Namespace) -> None: model, sampler, temperature=args.temperature, max_gen_len=args.max_gen_len # type: ignore[arg-type] ) - chatbot = create_chatbot(generator, tokenizer) + if model.family is None: + log.error("The model has no family name defined.") + + sys.exit(1) + + try: + chatbot = create_chatbot(model.family, generator, tokenizer) + except LookupError: + log.exception("The chatbot cannot be created.") + + sys.exit(1) rng_bag = RngBag.from_device_defaults(CPU, root_gang.device) # Set the seed for sequence generation. rng_bag.manual_seed(args.seed) - self._do_run(args.model_name, chatbot, root_gang) + self._do_run( + args.model_name, + chatbot, + generator, + tokenizer, + root_gang, + ) + + return 0 - def _do_run(self, chatbot_name: str, chatbot: Chatbot, gang: Gang) -> None: + def _do_run( + self, + chatbot_name: str, + chatbot: Chatbot, + generator: SequenceGenerator, + tokenizer: TextTokenizer, + gang: Gang, + ) -> None: dialog = [] if gang.rank == 0: @@ -201,7 +229,10 @@ def _do_run(self, chatbot_name: str, chatbot: Chatbot, gang: Gang) -> None: console.print(f"\n[blue bold]{chatbot_name}> ", end="") - response, _ = chatbot(dialog, stdout=True) + hook = PrintHook(console, tokenizer) + + with generator.register_step_hook(hook): + response, _ = chatbot(dialog) console.print("\n") @@ -216,7 +247,7 @@ def _do_run(self, chatbot_name: str, chatbot: Chatbot, gang: Gang) -> None: raise else: while True: - message_buffer: List[Optional[ChatMessage]] = [None] + message_buffer: list[object] = [None] gang.broadcast_objects(message_buffer) @@ -235,3 +266,65 @@ def _do_run(self, chatbot_name: str, chatbot: Chatbot, gang: Gang) -> None: response, _ = chatbot(dialog) dialog.append(response) + + +@final +class PrintHook: + _console: Console + _text_decoder: TextTokenDecoder + _first_print: bool + _prev_text_len: int + + def __init__(self, console: Console, tokenizer: TextTokenizer) -> None: + self._console = console + self._text_decoder = tokenizer.create_decoder() + self._first_print = True + self._prev_text_len = 0 + + def __call__( + self, + prompt_indices: Tensor, + seqs: Tensor, + step_scores: Tensor | None, + prefill: bool, + ) -> None: + if len(prompt_indices) != 1: + raise InternalError( + f"The length of `prompt_indices` is {len(prompt_indices)}." + ) + + # Do not print anything during prompt prefill. + if prefill: + return + + text = self._text_decoder(seqs[0]) + + text_len = len(text) + + # If this is our first print, determine the length of the prompt text. + if self._prev_text_len == 0: + prev_text = self._text_decoder(seqs[0][:-1]) + + prev_text_len = len(prev_text) + else: + prev_text_len = self._prev_text_len + + # Cache the length of the text so that we don't have to decode it twice + # in the next step. + self._prev_text_len = text_len + + # No need to print if we decoded a control symbol (e.g. EOS). + if text_len == prev_text_len: + return + + text = text[prev_text_len - text_len :] + + # Some models output several whitespace characters after the prompt. + if self._first_print: + text = text.lstrip() + if not text: + return + + self._first_print = False + + self._console.print(text, highlight=False, end="") diff --git a/src/fairseq2/recipes/lm/eval_nll.py b/src/fairseq2/recipes/lm/eval_nll.py new file mode 100644 index 000000000..5465aa742 --- /dev/null +++ b/src/fairseq2/recipes/lm/eval_nll.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import torch + +from fairseq2.assets import AssetNotFoundError +from fairseq2.config_registry import ConfigRegistry +from fairseq2.data.text import load_text_tokenizer +from fairseq2.datasets import LengthBatching +from fairseq2.datasets.instruction import ( + GenericInstructionDataset, + InstructionReadOptions, + load_instruction_dataset, +) +from fairseq2.logging import get_log_writer +from fairseq2.models import load_model +from fairseq2.models.decoder import DecoderModel +from fairseq2.models.sequence import SequenceBatch +from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa +from fairseq2.recipes.evaluator import Evaluator +from fairseq2.recipes.lm.instruction_finetune import ( + InstructionFinetuneCriterion, + InstructionValidUnit, +) +from fairseq2.recipes.utils.asset import ( + AssetReference, + asset_as_path, + retrieve_asset_card, +) +from fairseq2.recipes.utils.log import log_model +from fairseq2.recipes.utils.setup import setup_gangs, to_data_parallel +from fairseq2.typing import CPU, META, DataType +from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed + +log = get_log_writer(__name__) + + +@dataclass(kw_only=True) +class NLLEvalConfig: + """Holds configuration of the perplexity evaluator recipe""" + + # Data + dataset: AssetReference = "foo" + """The name, path or path to the asset card of the dataset to evaluate on.""" + + model: AssetReference = "llama3_1_8b" + """The name or path to the asset card of the wav2vec 2.0 model to evaluate.""" + + checkpoint_dir: Path | None = None + """The checkpoint directory containing models saved by :class:`FileCheckpointManager`.""" + + dtype: DataType = torch.bfloat16 + """The data type of the model.""" + + data_parallelism: Literal["ddp", "fsdp"] = "fsdp" + """The data parallelism API to use.""" + + fsdp_wrap_granularity: Literal["layer", "stack", "model"] = "layer" + """The granularity at which to wrap the model.""" + + fsdp_reshard_after_forward: bool = True + """If ``True``, reshards the parameters only after the backward pass.""" + + tensor_parallel_size: int = 1 + """The size of tensor parallelism.""" + + mixed_precision: Literal["none", "static", "dynamic"] = "static" + """ + If 'none', the whole training will be run in `dtype`. If 'static', forward + and backward passes will be run in `dtype`, but the optimizer step will be + run in full precision. If 'dynamic', forward and backward passes will be run + with `torch.amp` in `dtype`, but the optimizer step will be run in full + precision. + """ + + max_num_tokens: int = 8192 * 2 + """The maximum number of tokens per batch.""" + + min_seq_len: int = 1 + """The minimum sequence length.""" + + max_seq_len: int = 8192 + """The maximum sequence length.""" + + valid_split: str = "default" + """The name of the valid data split.""" + + example_shuffle_window: int = 10_000 + """The size of the sliding window for shuffling examples.""" + + batch_shuffle_window: int = 1000 + """The size of the sliding window for shuffling batches.""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + seed: int = 2 + """The random number generator seed to use.""" + + +nll_eval_presets = ConfigRegistry[NLLEvalConfig]() + +nll_eval_preset = nll_eval_presets.decorator + + +@nll_eval_preset("llama3_1_base_eval") +def _llama3_1_base_eval() -> NLLEvalConfig: + return NLLEvalConfig() + + +@torch.inference_mode() +def load_nll_evaluator( + config: NLLEvalConfig, output_dir: Path +) -> Evaluator[SequenceBatch]: + wall_watch = Stopwatch(start=True) + + root_gang, gangs = setup_gangs(log, tp_size=config.tensor_parallel_size) + + dp_gang = gangs["dp"] # data + tp_gang = gangs["tp"] # tensor + + # Load the tokenizer. + model_card = retrieve_asset_card(config.model) + + log.info("Loading {} tokenizer.", model_card.name) + + tokenizer = load_text_tokenizer(model_card) + + log.info("Tokenizer loaded.") + + # Load the dataset. + try: + dataset_card = retrieve_asset_card(config.dataset) + except AssetNotFoundError: + dataset_card = None + + if dataset_card is not None: + log.info("Loading {} preference optimization dataset.", dataset_card.name) + + dataset = load_instruction_dataset(dataset_card) + + log.info("Dataset loaded.") + else: + dataset_path = asset_as_path(config.dataset) + + dataset = GenericInstructionDataset.from_path(dataset_path) + + seed = config.seed + + # Load the model + manual_seed(seed, CPU, root_gang.device) + + seed += 1 + + init_device = META + + dtype = config.dtype if config.mixed_precision == "none" else torch.float32 + + gangs = {"dp": dp_gang, "tp": tp_gang} + + model_card = retrieve_asset_card(config.model) + + log.info("Loading {} model on data parallel rank 0 (per shard).", model_card.name) # fmt: skip + + if dp_gang.rank == 0: + init_device = root_gang.device + model = load_model( + model_card, + gangs=gangs, + device=init_device, + dtype=dtype, + ) + + root_gang.barrier() + + log.info("Model loaded on rank 0.") + + if not isinstance(model, DecoderModel): + raise ValueError( + f"The model must be of type `{DecoderModel}`, but is of type `{type(model)}` instead." + ) + + mp_dtype = config.dtype if config.mixed_precision == "static" else None + + dp_model = to_data_parallel( + model, + dp_gang, + config.data_parallelism, + log, + fsdp_broadcast_state=True, # loading checkpoints not supported + fsdp_reshard_after_forward=config.fsdp_reshard_after_forward, + fsdp_mixed_precision_dtype=mp_dtype, + fsdp_fp32_reduce=True, + fsdp_wrap_granularity=config.fsdp_wrap_granularity, + ) + + enable_memory_efficient_torch_sdpa(dp_model, False) + + log_model(dp_model, log, rank=root_gang.rank) + + # loading the eval unit + criterion = InstructionFinetuneCriterion(dp_model) + unit = InstructionValidUnit(criterion, dp_gang) + + options = InstructionReadOptions( + example_shuffle_window=config.example_shuffle_window, + batch_shuffle_window=config.batch_shuffle_window, + sync_mode="until_last", + num_accumulate=1, + num_prefetch=config.num_prefetch, + seed=seed, + ) + + data_reader = dataset.create_reader( + config.valid_split, + tokenizer, + dp_gang, + config.min_seq_len, + config.max_seq_len, + LengthBatching(config.max_num_tokens), + options=options, + ) + + # TODO: Fix once we support static mixed precision on one device. + if config.mixed_precision == "static": + amp = root_gang.size == 1 or config.data_parallelism != "fsdp" + else: + amp = config.mixed_precision == "dynamic" + + # Initialize the evaluator. + return Evaluator[SequenceBatch]( + units=[unit], + data_readers=[data_reader], + root_gang=root_gang, + dtype=config.dtype, + amp=amp, + tb_dir=output_dir.joinpath("tb"), + metrics_dir=output_dir.joinpath("metrics"), + seed=seed, + wall_watch=wall_watch, + ) diff --git a/src/fairseq2/recipes/lm/instruction_finetune.py b/src/fairseq2/recipes/lm/instruction_finetune.py index 0a35c07e1..90a391a16 100644 --- a/src/fairseq2/recipes/lm/instruction_finetune.py +++ b/src/fairseq2/recipes/lm/instruction_finetune.py @@ -6,22 +6,24 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path -from typing import Literal, Optional, Tuple, final +from typing import Any, Literal, final import torch import torch.distributed from torch import Tensor from torch.nn import Module +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager from fairseq2.config_registry import ConfigRegistry from fairseq2.data.text import load_text_tokenizer -from fairseq2.datasets import LengthBatching +from fairseq2.datasets import Batching, LengthBatching, StaticBatching from fairseq2.datasets.instruction import ( GenericInstructionDataset, + InstructionReadOptions, load_instruction_dataset, ) from fairseq2.gang import Gang @@ -35,9 +37,10 @@ ) from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa -from fairseq2.optim import AdamW -from fairseq2.optim.lr_scheduler import CosineAnnealingLR +from fairseq2.optim import AdamWConfig, create_optimizer +from fairseq2.optim.lr_scheduler import CosineAnnealingLRConfig, create_lr_scheduler from fairseq2.recipes.common_metrics import SequenceMetricBag +from fairseq2.recipes.evaluator import AbstractEvalUnit from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer from fairseq2.recipes.utils.asset import ( AssetReference, @@ -51,13 +54,14 @@ setup_gangs, to_data_parallel, ) -from fairseq2.typing import META, DataType, override +from fairseq2.typing import CPU, META, DataType from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class InstructionFinetuneConfig: """Holds the configuration of a language model instruction-finetuning task.""" @@ -65,12 +69,27 @@ class InstructionFinetuneConfig: dataset: AssetReference = "foo" # TODO: change! """The name, path, or path to the asset card of the instruction dataset.""" + train_split: str = "default" + """The name of the train data split.""" + + valid_split: str | None = None + """The name of the valid data split.""" + + min_seq_len: int = 1 + """The minimum sequence length.""" + max_seq_len: int = 8192 """The maximum sequence length.""" max_num_tokens: int = 8192 * 2 """The maximum number of tokens per batch.""" + batch_size: int | None = None + """If not ``None``, ignores `max_num_tokens` and each batch will have `batch_size` examples.""" + + max_num_valid_tokens: int | None = None + """The maximum number of tokens per validation batch.""" + example_shuffle_window: int = 10_000 """The size of the sliding window for shuffling examples.""" @@ -80,16 +99,44 @@ class InstructionFinetuneConfig: num_prefetch: int = 4 """The number of batches to prefetch in background.""" + src_encode_mode: str = "prompt" + """The encode mode for the prompt, determines what special tokens to add.""" + + tgt_encode_mode: str = "prompt_response" + """The encode mode for the target, determines what special tokens to add.""" + # Model - model: AssetReference = "llama3_8b_instruct" + model: AssetReference = "llama3_1_8b_instruct" """The name or path to the asset card of the language model to finetune.""" + model_config: Any = None + """ + The model configuration overrides. The provided values must be compatible + with the checkpoint; otherwise, the model will fail to load. + """ + dtype: DataType = torch.bfloat16 """The data type of the model.""" + mixed_precision: Literal["none", "static", "dynamic"] = "static" + """ + If 'none', the whole training will be run in `dtype`. If 'static', forward + and backward passes will be run in `dtype`, but the optimizer step will be + run in full precision. If 'dynamic', forward and backward passes will be run + with `torch.amp` in `dtype`, but the optimizer step will be run in full + precision. + """ + data_parallelism: Literal["ddp", "fsdp"] = "fsdp" """The data parallelism API to use.""" + fsdp_local_world_size: int | None = None + """ + If not ``None``, enables hybrid sharding. The model will be fully sharded + within each worker group of size ``local_world_size`` and + will be replicated across groups. + """ + fsdp_wrap_granularity: Literal["layer", "stack", "model"] = "layer" """The granularity at which to wrap the model.""" @@ -106,58 +153,73 @@ class InstructionFinetuneConfig: """If ``True``, applies ``torch.compile()`` to the decoder. (experimental)""" # Optimizer, LR, and Loss - lr: float = 5.5e-06 - """The initial (post-warm-up) learning rate.""" + optimizer: str = "adamw" + """The optimizer.""" - betas: Tuple[float, float] = (0.9, 0.95) - """The coefficients of AdamW.""" - - final_lr_ratio: float = 0.2 - """The ratio of the final learning rate to :attr:`lr`.""" + optimizer_config: Any = field( + default_factory=lambda: AdamWConfig( + lr=5.5e-06, betas=(0.9, 0.95), weight_decay=0.1 + ) + ) + """The configuration of the optimizer.""" - weight_decay: float = 0.1 - """The weight decay coefficient of AdamW.""" + lr_scheduler: str = "cosine-annealing" + """The learning rate scheduler.""" - num_lr_warmup_steps: int = 0 - """The number of learning rate warm-up steps.""" + lr_scheduler_config: Any = field( + default_factory=lambda: CosineAnnealingLRConfig(final_lr_scale=0.2) + ) + """The configuration of the learning rate scheduler.""" gradient_accumulation: int = 1 """The number of steps to accumulate gradients before an optimizer update.""" - max_gradient_norm: Optional[float] = None + max_gradient_norm: float | None = None """The maximum gradient norm. If ``None``, no clipping will be applied.""" - fp16_loss_scale: Tuple[float, float] = (128.0, 0.0001) + fp16_loss_scale: tuple[float, float] = (128.0, 0.0001) """The initial and minimum loss scale for fp16 training.""" # Regime max_num_steps: int = 5000 - """The maximum number of steps to train for.""" + """The maximum number of steps to train for. Note that max_num_steps is used as CosineLRScheduler argument!""" - max_num_data_epochs: Optional[int] = None + max_num_data_epochs: int | None = None """The maximum number of data epochs to train for.""" + validate_after_n_steps: int = 0 + """The number of steps after which to start validating the model.""" + + validate_every_n_steps: int = 100 + """The step interval at which to validate the model.""" + checkpoint_every_n_steps: int = 1000 """The step interval at which to checkpoint.""" - keep_last_n_checkpoints: Optional[int] = 1 + checkpoint_every_n_data_epochs: int | None = None + """The data epoch interval at which to checkpoint.""" + + keep_last_n_checkpoints: int | None = 1 """The number of checkpoints to keep. If ``None``, none will be deleted.""" - keep_last_n_models: Optional[int] = None + keep_last_n_models: int | None = None """The number of checkpoint models to keep. If ``None``, none will be deleted.""" publish_metrics_every_n_steps: int = 10 """The step interval at which to publish training metrics.""" - # Checkpointing - resume_checkpoint_dir: Optional[Path] = None + publish_metrics_every_n_data_epochs: int | None = None + """The data epoch interval at which to publish training metrics.""" + + # Checkpoint + resume_checkpoint_dir: Path | None = None """If not ``None``, adds the specified path to the default asset store.""" # Misc seed: int = 2 """The random number generator seed to use.""" - profile: Optional[Tuple[int, int]] = None + profile: tuple[int, int] | None = None """The number of steps that the PyTorch profiler should skip and then record.""" monitored_gang: bool = False @@ -166,18 +228,63 @@ class InstructionFinetuneConfig: anomaly_detection: bool = False """If ``True``, turns on anomaly detection feature in ``torch.autograd``.""" + wandb_project: str | None = None + """If not ``None``, sets the project name for W&B logging.""" + + wandb_run_name: str | None = None + """If not ``None``, sets the run name for W&B logging. If None, then W&B creates a random name.""" + instruction_finetune_presets = ConfigRegistry[InstructionFinetuneConfig]() instruction_finetune_preset = instruction_finetune_presets.decorator +@dataclass(kw_only=True) +class DropoutConfig: + dropout_p: float = 0.0 + + +@instruction_finetune_preset("llama3_1_instruct") +def _llama3_1_instruct() -> InstructionFinetuneConfig: + config = InstructionFinetuneConfig() + config.model_config = DropoutConfig() + return config + + +@instruction_finetune_preset("llama3_1_instruct_constant_lr") +def _llama3_1_instruct_constant_lr() -> InstructionFinetuneConfig: + config = _llama3_1_instruct() + # setting up final lr to be the optmiizer base lr, lr_mul is 1.0 by default + config.lr_scheduler_config.final_lr = config.optimizer_config.lr + return config + + +@instruction_finetune_preset("llama3_1_instruct_lr_anneal_0") +def _llama3_1_instruct_lr_anneal_0() -> InstructionFinetuneConfig: + config = _llama3_1_instruct() + # setting up final lr to be 0.0 at the end of the cycle + config.lr_scheduler_config.final_lr = 0.0 + return config + + +@instruction_finetune_preset("llama3_1_70b_instruct") +def _llama3_1_70b_instruct() -> InstructionFinetuneConfig: + config = _llama3_1_instruct() + + config.model = "llama3_1_70b_instruct" + config.tensor_parallel_size = 8 + + return config + + @instruction_finetune_preset("llama2_7b_chat") def _llama2_7b_chat() -> InstructionFinetuneConfig: - config = _llama3_8b_instruct() + config = _llama3_1_instruct() config.max_seq_len = 4096 config.max_num_tokens = 4096 * 2 + config.max_num_valid_tokens = 4096 * 2 config.model = "llama2_7b_chat" return config @@ -193,21 +300,6 @@ def _llama2_70b_chat() -> InstructionFinetuneConfig: return config -@instruction_finetune_preset("llama3_8b_instruct") -def _llama3_8b_instruct() -> InstructionFinetuneConfig: - return InstructionFinetuneConfig() - - -@instruction_finetune_preset("llama3_70b_instruct") -def _llama3_70b_instruct() -> InstructionFinetuneConfig: - config = _llama3_8b_instruct() - - config.model = "llama3_70b_instruct" - config.tensor_parallel_size = 8 - - return config - - def load_instruction_finetuner( config: InstructionFinetuneConfig, output_dir: Path ) -> Trainer[SequenceBatch]: @@ -230,8 +322,6 @@ def load_instruction_finetuner( CheckpointModelMetadataProvider(config.resume_checkpoint_dir) ) - seed = config.seed - model_card = retrieve_asset_card(config.model) # Load the tokenizer. @@ -258,15 +348,27 @@ def load_instruction_finetuner( dataset = GenericInstructionDataset.from_path(dataset_path) - # Load the model. + seed = config.seed + + # Load the model + manual_seed(seed, CPU, root_gang.device) + + seed += 1 + init_device = META + dtype = config.dtype if config.mixed_precision == "none" else torch.float32 + has_checkpoint = checkpoint_manager.has_checkpoint() if has_checkpoint: try: model = load_model( - model_card, gangs=gangs, device=init_device, dtype=torch.float32 + model_card, + gangs=gangs, + unstructured_config=config.model_config, + device=init_device, + dtype=dtype, ) except ValueError as ex: raise ValueError( @@ -282,7 +384,11 @@ def load_instruction_finetuner( try: model = load_model( - model_card, gangs=gangs, device=init_device, dtype=torch.float32 + model_card, + gangs=gangs, + unstructured_config=config.model_config, + device=init_device, + dtype=dtype, ) except ValueError as ex: raise ValueError( @@ -293,23 +399,26 @@ def load_instruction_finetuner( log.info("Model loaded on data parallel rank 0.") - check_model_type(model, DecoderModel) + if not isinstance(model, DecoderModel): + raise ValueError( + f"The model must be of type `{DecoderModel}`, but is of type `{type(model)}` instead." + ) - checkpoint_manager.save_model_metadata( - base_asset=model_card.name, family=model.family - ) + checkpoint_manager.save_model_metadata(base_asset=model_card.name) + + mp_dtype = config.dtype if config.mixed_precision == "static" else None dp_model = to_data_parallel( model, dp_gang, config.data_parallelism, log, - fsdp_skip_init=True, fsdp_broadcast_state=not has_checkpoint, fsdp_reshard_after_forward=config.fsdp_reshard_after_forward, - fsdp_mixed_precision_dtype=config.dtype, + fsdp_mixed_precision_dtype=mp_dtype, fsdp_fp32_reduce=True, fsdp_wrap_granularity=config.fsdp_wrap_granularity, + fsdp_local_world_size=config.fsdp_local_world_size, ) if config.activation_checkpointing: @@ -325,36 +434,128 @@ def load_instruction_finetuner( log_model(dp_model, log, rank=root_gang.rank) - # Initialize the train unit and the optimizer. - unit = InstructionFinetuneUnit(dp_model, dp_gang) + # Initialize the criterion. + criterion = InstructionFinetuneCriterion(dp_model) - data_reader = dataset.create_reader( - tokenizer, - dp_gang, - config.max_seq_len, - batching=LengthBatching(config.max_num_tokens), - example_shuffle_window=config.example_shuffle_window, - batch_shuffle_window=config.batch_shuffle_window, - num_accumulate=config.gradient_accumulation, - num_prefetch=config.num_prefetch, - seed=seed, - ) + # Initialize the unit. + unit = InstructionFinetuneUnit(criterion, dp_gang) + + try: + batching: Batching + + if config.batch_size is not None: + batching = StaticBatching(config.batch_size) + else: + batching = LengthBatching(config.max_num_tokens) + + options = InstructionReadOptions( + example_shuffle_window=config.example_shuffle_window, + batch_shuffle_window=config.batch_shuffle_window, + num_accumulate=config.gradient_accumulation, + num_prefetch=config.num_prefetch, + source_encode_mode=config.src_encode_mode, + target_encode_mode=config.tgt_encode_mode, + seed=seed, + ) + + data_reader = dataset.create_reader( + config.train_split, + tokenizer, + dp_gang, + config.min_seq_len, + config.max_seq_len, + batching, + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex seed += 1 - optimizer = AdamW( - model.parameters(), - lr=config.lr, - betas=config.betas, - weight_decay=config.weight_decay, - ) + # Initialize the optimizer. + try: + optimizer = create_optimizer( + config.optimizer, dp_model, config.optimizer_config + ) + except ValueError as ex: + raise ValueError( + "The optimizer cannot be created. See nested exception for details." + ) from ex - lr_scheduler = CosineAnnealingLR( - optimizer, - cycle_len=config.max_num_steps - config.num_lr_warmup_steps, - num_warmup_steps=config.num_lr_warmup_steps, - final_lr=config.lr * config.final_lr_ratio, - ) + # Initialize the learning rate scheduler. + try: + lr_scheduler = create_lr_scheduler( + config.lr_scheduler, + optimizer, + config.lr_scheduler_config, + max_num_steps=config.max_num_steps, + ) + except ValueError as ex: + raise ValueError( + "The learning rate scheduler cannot be created. See nested exception for details." + ) from ex + + # Initialize the validation unit. + if config.valid_split is not None: + valid_unit = InstructionValidUnit(criterion, dp_gang) + + max_num_tokens = config.max_num_valid_tokens or config.max_num_tokens + + options = InstructionReadOptions( + example_shuffle_window=config.example_shuffle_window, + batch_shuffle_window=config.batch_shuffle_window, + sync_mode="until_last", + num_accumulate=config.gradient_accumulation, + num_prefetch=config.num_prefetch, + source_encode_mode=config.src_encode_mode, + target_encode_mode=config.tgt_encode_mode, + seed=seed, + ) + + try: + valid_data_reader = dataset.create_reader( + config.valid_split, + tokenizer, + dp_gang, + config.min_seq_len, + config.max_seq_len, + LengthBatching(max_num_tokens), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + + valid_units = [valid_unit] + + valid_data_readers = [valid_data_reader] + else: + valid_units = None + + valid_data_readers = None + + seed += 1 + + # TODO: Fix once we support static mixed precision on one device. + if config.mixed_precision == "static": + amp = root_gang.size == 1 or config.data_parallelism != "fsdp" + else: + amp = config.mixed_precision == "dynamic" + + if config.wandb_project is not None: + if config.wandb_run_name is None: + raise ValueError( + "`wandb_run_name` must be specified when `wandb_project` is set." + ) + + wandb_dir = output_dir.joinpath("wandb") + + wandb_options = (wandb_dir, config.wandb_project, config.wandb_run_name) + else: + wandb_options = None # Initialize the trainer. return Trainer[SequenceBatch]( @@ -368,15 +569,23 @@ def load_instruction_finetuner( lr_scheduler=lr_scheduler, fp16_loss_scale=config.fp16_loss_scale, max_gradient_norm=config.max_gradient_norm, + amp=amp, max_num_steps=config.max_num_steps, max_num_data_epochs=config.max_num_data_epochs, + valid_units=valid_units, + valid_data_readers=valid_data_readers, + validate_after_n_steps=config.validate_after_n_steps, + validate_every_n_steps=config.validate_every_n_steps, checkpoint_manager=checkpoint_manager, checkpoint_every_n_steps=config.checkpoint_every_n_steps, + checkpoint_every_n_data_epochs=config.checkpoint_every_n_data_epochs, keep_last_n_checkpoints=config.keep_last_n_checkpoints, keep_last_n_models=config.keep_last_n_models, tb_dir=output_dir.joinpath("tb"), metrics_dir=output_dir.joinpath("metrics"), + wandb_options=wandb_options, publish_metrics_every_n_steps=config.publish_metrics_every_n_steps, + publish_metrics_every_n_data_epochs=config.publish_metrics_every_n_data_epochs, profile=config.profile, anomaly_detection=config.anomaly_detection, seed=seed, @@ -386,25 +595,60 @@ def load_instruction_finetuner( @final class InstructionFinetuneUnit(AbstractTrainUnit[SequenceBatch]): - """Represents a language model instruction-finetuning unit.""" + _criterion: InstructionFinetuneCriterion + _metric_bag: SequenceMetricBag + + def __init__(self, criterion: InstructionFinetuneCriterion, gang: Gang) -> None: + super().__init__(criterion.model) + + self._criterion = criterion + + self._metric_bag = SequenceMetricBag(gang) + + @override + def __call__(self, batch: SequenceBatch) -> tuple[Tensor, int]: + return self._criterion(batch, self._metric_bag) + @property + @override + def metric_bag(self) -> SequenceMetricBag: + return self._metric_bag + + +@final +class InstructionValidUnit(AbstractEvalUnit[SequenceBatch]): + _criterion: InstructionFinetuneCriterion _metric_bag: SequenceMetricBag - def __init__(self, model: Module, gang: Gang) -> None: - """ - :param model: - The language model. Might be wrapped with DDP or FSDP. - :param gang: - The gang for distributed training. - """ - super().__init__(model) + def __init__(self, criterion: InstructionFinetuneCriterion, gang: Gang) -> None: + super().__init__(criterion.model) - check_model_type(model, DecoderModel) + self._criterion = criterion self._metric_bag = SequenceMetricBag(gang) @override - def __call__(self, batch: SequenceBatch) -> Tuple[Tensor, int]: + def __call__(self, batch: SequenceBatch) -> None: + self._criterion(batch, self._metric_bag) + + @property + @override + def metric_bag(self) -> SequenceMetricBag: + return self._metric_bag + + +@final +class InstructionFinetuneCriterion: + _model: Module + + def __init__(self, model: Module) -> None: + check_model_type(model, DecoderModel) + + self._model = model + + def __call__( + self, batch: SequenceBatch, metric_bag: SequenceMetricBag + ) -> tuple[Tensor, int]: input_batch, target_batch = as_auto_regressive_input(batch) output = self._forward(input_batch) @@ -413,9 +657,9 @@ def __call__(self, batch: SequenceBatch) -> Tuple[Tensor, int]: target_batch.seqs, loss_mask=target_batch.target_mask ) - self._metric_bag.update_nll_loss(target_batch, loss.detach()) + metric_bag.update_nll_loss(target_batch, loss) - self._metric_bag.update_batch_metrics(target_batch) + metric_bag.update_batch_metrics(target_batch) return loss, target_batch.num_target_elements() @@ -423,6 +667,5 @@ def _forward(self, batch: SequenceBatch) -> SequenceModelOutput: return self._model(batch) # type: ignore[no-any-return] @property - @override - def metric_bag(self) -> SequenceMetricBag: - return self._metric_bag + def model(self) -> Module: + return self._model diff --git a/src/fairseq2/recipes/lm/preference_finetune.py b/src/fairseq2/recipes/lm/preference_finetune.py deleted file mode 100644 index 2b2bab768..000000000 --- a/src/fairseq2/recipes/lm/preference_finetune.py +++ /dev/null @@ -1,443 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from pathlib import Path -from typing import Literal, Optional, Tuple, Union - -import torch -import torch.distributed -from torch.nn import Module - -from fairseq2.assets import AssetNotFoundError, default_asset_store -from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager -from fairseq2.config_registry import ConfigRegistry -from fairseq2.data.text import load_text_tokenizer -from fairseq2.datasets import LengthBatching -from fairseq2.datasets.preference import ( - GenericPreferenceOptimizationDataset, - PreferenceOptimizationBatch, - load_preference_optimization_dataset, -) -from fairseq2.logging import get_log_writer -from fairseq2.models import load_model -from fairseq2.models.decoder import DecoderModel -from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing -from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa -from fairseq2.nn.utils.module import freeze_parameters -from fairseq2.optim import AdamW -from fairseq2.optim.lr_scheduler import CosineAnnealingLR -from fairseq2.recipes.lm.preference_units.dpo_unit import ( - DpoFinetuneConfig, - DpoFinetuneUnit, -) -from fairseq2.recipes.lm.preference_units.simpo_unit import ( - SimpoFinetuneConfig, - SimpoFinetuneUnit, -) -from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer -from fairseq2.recipes.utils.asset import retrieve_asset_card -from fairseq2.recipes.utils.log import log_model -from fairseq2.recipes.utils.setup import ( - broadcast_model, - compile_model, - setup_gangs, - to_data_parallel, -) -from fairseq2.typing import META, DataType -from fairseq2.utils.profiler import Stopwatch - -log = get_log_writer(__name__) - - -@dataclass -class PreferenceOptimizationConfig: - """Holds the configuration of a language model preference-finetuning task.""" - - # Data - dataset: Union[str, Path] = "openeft" # TODO: change! - """The name, path, or path to the asset card of the preference optimization dataset.""" - - max_seq_len: int = 8192 - """The maximum sequence length.""" - - max_num_tokens: int = 8192 * 2 - """The maximum number of tokens per batch.""" - - example_shuffle_window: int = 10_000 - """The size of the sliding window for shuffling examples.""" - - batch_shuffle_window: int = 1000 - """The size of the sliding window for shuffling batches.""" - - num_prefetch: int = 4 - """The number of batches to prefetch in background.""" - - # Criterion - criterion: Literal["dpo", "simpo"] = "dpo" - """The type of preference optimization to perform.""" - - dpo: DpoFinetuneConfig = field(default_factory=lambda: DpoFinetuneConfig()) - """The configuration for Direct Preference Optimization.""" - - simpo: SimpoFinetuneConfig = field(default_factory=lambda: SimpoFinetuneConfig()) - """The configuration for SimPO.""" - - # Model - model: Union[str, Path] = "llama3_8b_instruct" - """The name or path to the asset card of the language model to finetune.""" - - dtype: DataType = torch.bfloat16 - """The data type of the model.""" - - data_parallelism: Literal["ddp", "fsdp"] = "fsdp" - """The data parallelism API to use.""" - - fsdp_wrap_granularity: Literal["layer", "stack", "model"] = "layer" - """The granularity at which to wrap the model.""" - - fsdp_reshard_after_forward: bool = True - """If ``True``, reshards the parameters only after the backward pass.""" - - tensor_parallel_size: int = 1 - """The size of tensor parallelism.""" - - activation_checkpointing: bool = True - """If ``True``, uses layer-wise activation checkpointing.""" - - torch_compile: bool = False - """If ``True``, applies ``torch.compile()`` to the decoder. (experimental)""" - - # Optimizer, LR, and Loss - lr: float = 5.5e-06 - """The initial (post-warm-up) learning rate.""" - - betas: Tuple[float, float] = (0.9, 0.95) - """The coefficients of AdamW.""" - - final_lr_ratio: float = 0.2 - """The ratio of the final learning rate to :attr:`lr`.""" - - weight_decay: float = 0.1 - """The weight decay coefficient of AdamW.""" - - num_lr_warmup_steps: int = 0 - """The number of learning rate warm-up steps.""" - - gradient_accumulation: int = 1 - """The number of steps to accumulate gradients before an optimizer update.""" - - max_gradient_norm: Optional[float] = None - """The maximum gradient norm. If ``None``, no clipping will be applied.""" - - fp16_loss_scale: Tuple[float, float] = (128.0, 0.0001) - """The initial and minimum loss scale for fp16 training.""" - - # Regime - max_num_steps: int = 5000 - """The maximum number of steps to train for.""" - - max_num_data_epochs: Optional[int] = None - """The maximum number of data epochs to train for.""" - - checkpoint_every_n_steps: int = 1000 - """The step interval at which to checkpoint.""" - - keep_last_n_checkpoints: Optional[int] = 1 - """The number of checkpoints to keep. If ``None``, none will be deleted.""" - - keep_last_n_models: Optional[int] = None - """The number of checkpoint models to keep.""" - - publish_metrics_every_n_steps: int = 10 - """The step interval at which to publish training metrics.""" - - # Checkpointing - resume_checkpoint_dir: Optional[Path] = None - """If not ``None``, adds the specified path to the default asset store.""" - - # Misc - seed: int = 2 - """The random number generator seed to use.""" - - profile: Optional[Tuple[int, int]] = None - """The number of steps that the PyTorch profiler should skip and then record.""" - - monitored_gang: bool = False - """If ``True``, puts a monitored barrier before every collective call.""" - - anomaly_detection: bool = False - """If ``True``, turns on anomaly detection feature in ``torch.autograd``.""" - - -preference_finetune_presets = ConfigRegistry[PreferenceOptimizationConfig]() - -preference_finetune_preset = preference_finetune_presets.decorator - - -@preference_finetune_preset("simpo") -def _simpo() -> PreferenceOptimizationConfig: - cfg = PreferenceOptimizationConfig() - cfg.max_num_tokens = 1200 - cfg.max_seq_len = 600 - cfg.model = "llama3_8b" - cfg.simpo = SimpoFinetuneConfig() - return cfg - - -@preference_finetune_preset("llama3_8b_instruct") -def _llama3_8b_instruct() -> PreferenceOptimizationConfig: - cfg = PreferenceOptimizationConfig() - cfg.max_num_tokens = 1000 - cfg.max_seq_len = 1000 - cfg.max_gradient_norm = 1.0 - return cfg - - -# batch size and min lengths are tuned for OA2 in this preset! -@preference_finetune_preset("llama3_70b_instruct_openassistant2") -def _llama3_70b_instruct_openassistant2() -> PreferenceOptimizationConfig: - cfg = PreferenceOptimizationConfig() - cfg.model = "llama3_70b_instruct" - cfg.dpo = DpoFinetuneConfig() - cfg.tensor_parallel_size = 8 - cfg.max_num_tokens = ( - 200 # 70B DPO training might catch OOM, tune the effective batch size if needed - ) - cfg.max_seq_len = 200 - cfg.max_gradient_norm = 1.0 - cfg.gradient_accumulation = 8 # to address small batch size - return cfg - - -def load_preference_finetuner( - config: PreferenceOptimizationConfig, output_dir: Path -) -> Trainer[PreferenceOptimizationBatch]: - """Load a :class:`Trainer` for language model preference optimization-finetuning.""" - wall_watch = Stopwatch(start=True) - - root_gang, gangs = setup_gangs( - log, tp_size=config.tensor_parallel_size, monitored=config.monitored_gang - ) - - dp_gang = gangs["dp"] # data - tp_gang = gangs["tp"] # tensor - - checkpoint_manager = FileCheckpointManager( - output_dir.joinpath("checkpoints"), root_gang, dp_gang=dp_gang, tp_gang=tp_gang - ) - - if config.resume_checkpoint_dir is not None: - default_asset_store.metadata_providers.append( - CheckpointModelMetadataProvider(config.resume_checkpoint_dir) - ) - - # Load the tokenizer. - model_card = retrieve_asset_card(config.model) - - log.info("Loading {} tokenizer.", model_card.name) - - tokenizer = load_text_tokenizer(model_card) - - log.info("Tokenizer loaded.") - - # Load the dataset. - try: - dataset_card = retrieve_asset_card(config.dataset) - except AssetNotFoundError: - dataset_card = None - - if dataset_card is not None: - log.info("Loading {} preference optimization dataset.", dataset_card.name) - - dataset = load_preference_optimization_dataset(dataset_card) - - log.info("Dataset loaded.") - else: - try: - path = Path(config.dataset) - except ValueError: - raise AssetNotFoundError( - config.dataset, f"An asset with the name '{config.dataset}' cannot be found." # type: ignore[arg-type] - ) - - dataset = GenericPreferenceOptimizationDataset.from_path(path) - log.info("Dataset loaded from path {}.", path) - - # Load the model. - init_device = META - - has_checkpoint = checkpoint_manager.has_checkpoint() - - if has_checkpoint: - model = load_model( - model_card, gangs=gangs, device=init_device, dtype=torch.float32 - ) - # If we don't have a checkpoint, load the pretrained model on rank 0 and - # broadcast it to the gang. - else: - log.info("Loading {} model on data parallel rank 0 (per shard).", model_card.name) # fmt: skip - - if dp_gang.rank == 0: - init_device = root_gang.device - - model = load_model( - model_card, gangs=gangs, device=init_device, dtype=torch.float32 - ) - - root_gang.barrier() - - log.info("Model loaded on data parallel rank 0.") - - if not isinstance(model, DecoderModel): - raise ValueError("`config.model` must specify a decoder model.") - - checkpoint_manager.save_model_metadata( - base_asset=model_card.name, family=model.family - ) - - dp_model = to_data_parallel( - model, - dp_gang, - config.data_parallelism, - log, - fsdp_skip_init=True, - fsdp_broadcast_state=not has_checkpoint, - fsdp_reshard_after_forward=config.fsdp_reshard_after_forward, - fsdp_mixed_precision_dtype=config.dtype, - fsdp_fp32_reduce=True, - fsdp_wrap_granularity=config.fsdp_wrap_granularity, - ) - - if config.activation_checkpointing: - use_layerwise_activation_checkpointing(dp_model) - - if config.torch_compile: - model.decoder = compile_model(model.decoder, log) - - # TODO(balioglu): investigate! - # The memory efficient SDPA implementation in PyTorch is not stable when - # used with padded inputs. - enable_memory_efficient_torch_sdpa(dp_model, False) - - log_model(dp_model, log, rank=root_gang.rank) - - # Load the reference model. - def _load_reference_model( - reference_model_path: Union[str, Path], - reference_dtype: DataType, - reference_tensor_parallel_size: int, - ) -> Module: - reference_model_card = retrieve_asset_card(reference_model_path) - - log.info("Loading {} reference model on data parallel rank 0 (per shard).", reference_model_card.name) # fmt: skip - - if dp_gang.rank == 0: - init_device = dp_gang.device - else: - init_device = META - - # TODO: figure out how to load the reference model onto its own gangs - reference_model = load_model( - reference_model_card, - gangs=gangs, - device=init_device, - dtype=reference_dtype, - ) - - root_gang.barrier() - - log.info("Reference model loaded on data parallel rank 0.") - - reference_model.eval() - - freeze_parameters(reference_model) - - if dp_gang.size != 1: - broadcast_model(reference_model, dp_gang, log) - - return reference_model - - def _create_preference_unit( - config: PreferenceOptimizationConfig, - ) -> AbstractTrainUnit[PreferenceOptimizationBatch]: - # TODO: setup registers for TrainUnits to replace this - if config.criterion == "dpo": - dp_reference_model = _load_reference_model( - config.dpo.reference_model, - config.dpo.reference_dtype, - config.dpo.reference_tensor_parallel_size, - ) - return DpoFinetuneUnit( - dp_model, - dp_reference_model, - dp_gang, - config.dpo.dpo_beta, - config.dpo.nll_scale, - ) - if config.criterion == "simpo": - return SimpoFinetuneUnit( - dp_model, - dp_gang, - config.simpo.simpo_beta, - config.simpo.simpo_gamma, - config.simpo.nll_scale, - ) - # TODO: build an exception for this. is there one already? - raise ValueError(f"config.criterion_type '{config.criterion}' cannot be found.") - - # Initialize the train unit - unit = _create_preference_unit(config) - - data_reader = dataset.create_reader( - tokenizer, - dp_gang, - max_seq_len=config.max_seq_len, - batching=LengthBatching(config.max_num_tokens), - max_num_tokens=config.max_num_tokens, - example_shuffle_window=config.example_shuffle_window, - batch_shuffle_window=config.batch_shuffle_window, - num_accumulate=config.gradient_accumulation, - num_prefetch=config.num_prefetch, - seed=config.seed, - ) - - # Initialize the optimizer - optimizer = AdamW( - model.parameters(), - lr=config.lr, - betas=config.betas, - weight_decay=config.weight_decay, - ) - - lr_scheduler = CosineAnnealingLR( - optimizer, - cycle_len=config.max_num_steps - config.num_lr_warmup_steps, - num_warmup_steps=config.num_lr_warmup_steps, - final_lr=config.lr * config.final_lr_ratio, - ) - - # Initialize the trainer. - return Trainer[PreferenceOptimizationBatch]( - unit=unit, - data_reader=data_reader, - root_gang=root_gang, - dp_gang=dp_gang, - tp_gang=tp_gang, - dtype=config.dtype, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - fp16_loss_scale=config.fp16_loss_scale, - max_gradient_norm=config.max_gradient_norm, - max_num_steps=config.max_num_steps, - max_num_data_epochs=config.max_num_data_epochs, - checkpoint_manager=checkpoint_manager, - checkpoint_every_n_steps=config.checkpoint_every_n_steps, - keep_last_n_checkpoints=config.keep_last_n_checkpoints, - keep_last_n_models=config.keep_last_n_models, - tb_dir=output_dir.joinpath("tb"), - publish_metrics_every_n_steps=config.publish_metrics_every_n_steps, - profile=config.profile, - anomaly_detection=config.anomaly_detection, - seed=config.seed, - wall_watch=wall_watch, - ) diff --git a/src/fairseq2/recipes/lm/preference_finetune/__init__.py b/src/fairseq2/recipes/lm/preference_finetune/__init__.py new file mode 100644 index 000000000..486ab8b8d --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.recipes.lm.preference_finetune.cpo import CpoConfig as CpoConfig +from fairseq2.recipes.lm.preference_finetune.dpo import DpoConfig as DpoConfig +from fairseq2.recipes.lm.preference_finetune.orpo import OrpoConfig as OrpoConfig +from fairseq2.recipes.lm.preference_finetune.recipe import ( + load_preference_finetuner as load_preference_finetuner, +) +from fairseq2.recipes.lm.preference_finetune.recipe import ( + preference_finetune_presets as preference_finetune_presets, +) +from fairseq2.recipes.lm.preference_finetune.simpo import SimPOConfig as SimPOConfig +from fairseq2.recipes.lm.preference_finetune.utils import ( + preference_unit_factory as preference_unit_factory, +) + +# isort: split + +import fairseq2.recipes.lm.preference_finetune.cpo +import fairseq2.recipes.lm.preference_finetune.dpo +import fairseq2.recipes.lm.preference_finetune.orpo +import fairseq2.recipes.lm.preference_finetune.simpo diff --git a/src/fairseq2/recipes/lm/preference_finetune/cpo.py b/src/fairseq2/recipes/lm/preference_finetune/cpo.py new file mode 100644 index 000000000..9faea7d8e --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/cpo.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, cast, final + +import torch +import torch.distributed +from torch import Tensor +from torch.nn import Module +from torcheval.metrics import Mean +from typing_extensions import override + +from fairseq2.datasets.preference import PreferenceOptimizationBatch +from fairseq2.gang import Gang +from fairseq2.logging import get_log_writer +from fairseq2.metrics.recorder import format_as_float, register_metric_formatter +from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input +from fairseq2.recipes.lm.preference_finetune.utils import ( + PreferenceFinetuneMetricBag, + _gather_lprobs, + preference_unit_factory, +) +from fairseq2.recipes.trainer import AbstractTrainUnit + +log = get_log_writer(__name__) + + +@final +class CpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]): + """Represents the language model CPO-finetuning unit. Paper: https://arxiv.org/abs/2401.08417.""" + + _beta: float + _nll_scale: float + _metric_bag: CpoFinetuneMetricBag + + def __init__( + self, + model: Module, + gang: Gang, + beta: float = 1.0, + nll_scale: float = 1.0, + ) -> None: + super().__init__(model) + + self._beta = beta + self._nll_scale = nll_scale + + self._metric_bag = CpoFinetuneMetricBag(gang) + + @override + def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]: + chosen_batch = batch.chosen + chosen_input_batch, chosen_target_batch = as_auto_regressive_input(chosen_batch) + rejected_batch = batch.rejected + rejected_input_batch, rejected_target_batch = as_auto_regressive_input( + rejected_batch + ) + + chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch)) + rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch)) + + chosen_logps = _gather_lprobs(chosen_output, chosen_target_batch) + rejected_logps = _gather_lprobs(rejected_output, rejected_target_batch) + + cpo_loss = self._compute_cpo_loss(chosen_logps, rejected_logps) + + nll_loss = chosen_output.compute_loss( + chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask + ) + + self._metric_bag.update_cpo_loss(batch, cpo_loss) + + self._metric_bag.update_nll_loss(chosen_batch, nll_loss) + + self._metric_bag.update_sequence_lengths(batch) + + self._metric_bag.update_logps(batch, chosen_logps, rejected_logps) + + self._metric_bag.update_batch_metrics(chosen_batch) + + loss = ( + cpo_loss + + self._nll_scale + * nll_loss + * chosen_target_batch.batch_size + / chosen_target_batch.num_target_elements() + ) # normalization applied locally per-rank + + return loss, chosen_target_batch.batch_size + + def _compute_cpo_loss( + self, + chosen_logps: Tensor, + rejected_logps: Tensor, + ) -> Tensor: + cpo_loss = -torch.nn.functional.logsigmoid( + self._beta * (chosen_logps - rejected_logps) + ) + return cpo_loss.sum() + + @override + def set_step_nr(self, step_nr: int) -> None: + """Set the current training step number.""" + self._step_nr = step_nr + + @property + @override + def metric_bag(self) -> CpoFinetuneMetricBag: + return self._metric_bag + + +register_metric_formatter("cpo_loss", "CPO Loss", 0, format_as_float) + + +class CpoFinetuneMetricBag(PreferenceFinetuneMetricBag): + """Holds the metrics of a CPO preference finetuning task.""" + + cpo_loss: Mean + + def __init__(self, gang: Gang) -> None: + super().__init__(gang) + + self.register_metric("cpo_loss", Mean(device=gang.device), persistent=False) + + @torch.inference_mode() + def update_cpo_loss(self, batch: PreferenceOptimizationBatch, loss: Tensor) -> None: + """Update the CPO loss metric. + + :param batch: + The batch processed by the model. + :param loss: + The CPO loss of ``batch``. + """ + self.cpo_loss.update( + loss / batch.chosen.batch_size, weight=batch.chosen.batch_size + ) + + +@dataclass(kw_only=True) +class CpoConfig: + """Holds the CPO configuration of a language model preference-finetuning task.""" + + # Hyperparameters + beta: float = 1.0 + """The coefficient applied to the difference between preferred and dispreferred sequences.""" + + nll_scale: float = 1.0 + """The coefficient of NLL loss added to the CPO loss.""" + + +@preference_unit_factory("cpo") +def create_cpo_unit( + config: CpoConfig, model: Module, root_gang: Gang, gangs: Mapping[str, Gang] +) -> CpoFinetuneUnit: + dp_gang = gangs["dp"] # data + + return CpoFinetuneUnit(model, dp_gang, config.beta, config.nll_scale) diff --git a/src/fairseq2/recipes/lm/preference_finetune/dpo.py b/src/fairseq2/recipes/lm/preference_finetune/dpo.py new file mode 100644 index 000000000..ed460d351 --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/dpo.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, cast, final + +import torch +import torch.distributed +from torch import Tensor +from torch.nn import Module +from torcheval.metrics import Mean +from typing_extensions import override + +from fairseq2.datasets.preference import PreferenceOptimizationBatch +from fairseq2.gang import Gang +from fairseq2.logging import get_log_writer +from fairseq2.metrics.recorder import format_as_float, register_metric_formatter +from fairseq2.models.sequence import ( + SequenceBatch, + SequenceModelOutput, + as_auto_regressive_input, +) + +# from fairseq2.recipes.lm.preference_finetune.recipe import preference_unit_factory +from fairseq2.recipes.lm.preference_finetune.utils import ( + PreferenceFinetuneMetricBag, + _gather_lprobs_avg, + _load_reference_model, + preference_unit_factory, +) +from fairseq2.recipes.trainer import AbstractTrainUnit +from fairseq2.recipes.utils.asset import AssetReference +from fairseq2.typing import DataType + +log = get_log_writer(__name__) + + +@final +class DpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]): + """Represents the language model DPO-finetuning unit. Paper: https://arxiv.org/abs/2305.18290.""" + + _reference_model: Module | None + _beta: float + _nll_scale: float + _metric_bag: DpoFinetuneMetricBag + _length_normalization: bool + + def __init__( + self, + model: Module, + reference_model: Module | None, + gang: Gang, + beta: float = 0.1, + nll_scale: float = 1.0, + length_normalization: bool = False, + ) -> None: + super().__init__(model) + + self._reference_model = reference_model + self._beta = beta + self._nll_scale = nll_scale + self._length_normalization = length_normalization + + self._metric_bag = DpoFinetuneMetricBag(gang) + + @override + def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]: + chosen_batch = batch.chosen + chosen_input_batch, chosen_target_batch = as_auto_regressive_input(chosen_batch) + rejected_batch = batch.rejected + rejected_input_batch, rejected_target_batch = as_auto_regressive_input( + rejected_batch + ) + if ( + chosen_target_batch.target_mask is None + or rejected_target_batch.target_mask is None + ): + raise RuntimeError("target_mask attributes must exist for DPO loss") + + chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch)) + rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch)) + + chosen_logps, average_chosen_logps = _gather_lprobs_avg( + chosen_output, chosen_target_batch + ) + rejected_logps, average_rejected_logps = _gather_lprobs_avg( + rejected_output, rejected_target_batch + ) + + if self._reference_model is not None: + with torch.no_grad(): + ref_chosen_output = cast( + SequenceModelOutput, self._reference_model(chosen_batch) + ) + ref_rejected_output = cast( + SequenceModelOutput, self._reference_model(rejected_batch) + ) + ref_chosen_logps, ref_average_chosen_logps = _gather_lprobs_avg( + ref_chosen_output, chosen_target_batch + ) + ref_rejected_logps, ref_average_rejected_logps = _gather_lprobs_avg( + ref_rejected_output, rejected_target_batch + ) + elif ( + batch.reference_score_chosen is not None + and batch.reference_score_rejected is not None + ): + # reference scores must exist in the batch if reference model is None + ref_chosen_logps = batch.reference_score_chosen + ref_average_chosen_logps = ( + ref_chosen_logps / chosen_target_batch.target_mask.sum(-1) + ) + ref_rejected_logps = batch.reference_score_rejected + ref_average_rejected_logps = ( + ref_rejected_logps / rejected_target_batch.target_mask.sum(-1) + ) + else: + raise RuntimeError( + "Reference model is not initialized and data batch does not provide reference score, but at least one must exist." + ) + + if self._length_normalization: + _, _, dpo_loss = self._compute_dpo_loss( + average_chosen_logps, + ref_average_chosen_logps, + average_rejected_logps, + ref_average_rejected_logps, + ) + else: + _, _, dpo_loss = self._compute_dpo_loss( + chosen_logps, ref_chosen_logps, rejected_logps, ref_rejected_logps + ) + + nll_loss = chosen_output.compute_loss( + chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask + ) + + self._metric_bag.update_dpo_loss(batch, dpo_loss) + + self._metric_bag.update_nll_loss(chosen_batch, nll_loss) + + self._metric_bag.update_sequence_lengths(batch) + + self._metric_bag.update_logps(batch, chosen_logps, rejected_logps) + + self._metric_bag.update_batch_metrics(chosen_batch) + + loss = ( + dpo_loss + + self._nll_scale + * nll_loss + * chosen_target_batch.batch_size + / chosen_target_batch.num_target_elements() + ) # normalization applied locally per-rank + + return loss, chosen_target_batch.batch_size + + def _gather_lprobs( + self, output: SequenceModelOutput, target: SequenceBatch + ) -> tuple[Tensor, Tensor]: + logprobs = torch.log_softmax(output.logits, dim=-1) + per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze( + -1 + ) + total_logps = (per_token_logps * target.target_mask).sum(dim=-1) # [Batch, 1] + assert target.target_mask is not None + average_logps = total_logps / target.target_mask.sum(-1) + + return total_logps, average_logps + + def _compute_dpo_loss( + self, + chosen_logps: Tensor, + ref_chosen_logps: Tensor, + rejected_logps: Tensor, + ref_rejected_logps: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + logp_ratio_chosen = self._beta * (chosen_logps - ref_chosen_logps) + logp_ratio_rejected = self._beta * (rejected_logps - ref_rejected_logps) + dpo_loss = -torch.nn.functional.logsigmoid( + logp_ratio_chosen - logp_ratio_rejected + ) + return logp_ratio_chosen, logp_ratio_rejected, dpo_loss.sum() + + @override + def set_step_nr(self, step_nr: int) -> None: + self._step_nr = step_nr + + @property + @override + def metric_bag(self) -> DpoFinetuneMetricBag: + return self._metric_bag + + +register_metric_formatter("dpo_loss", "DPO Loss", 0, format_as_float) + + +class DpoFinetuneMetricBag(PreferenceFinetuneMetricBag): + """Holds the metrics of a DPO preference finetuning task.""" + + dpo_loss: Mean + + def __init__(self, gang: Gang) -> None: + super().__init__(gang) + + self.register_metric("dpo_loss", Mean(device=gang.device), persistent=False) + + @torch.inference_mode() + def update_dpo_loss(self, batch: PreferenceOptimizationBatch, loss: Tensor) -> None: + """Update the DPO loss metric. + + :param batch: + The batch processed by the model. + :param loss: + The DPO loss of ``batch``. + """ + self.dpo_loss.update( + loss / batch.chosen.batch_size, weight=batch.chosen.batch_size + ) + + +@dataclass(kw_only=True) +class DpoConfig: + """Holds the DPO configuration of a language model preference-finetuning task.""" + + # Reference Model + reference_model: AssetReference | None = "llama3_1_8b_instruct" + """The name, path, or path to the asset card of the reference model. If reference_model is None, recipe expects to get reference log-probabilities for chosen and rejected targets as float values in the data example (fields `reference_score_rejected` and `reference_score_chosen`).""" + + reference_dtype: DataType = torch.bfloat16 + """The data type of the reference model.""" + + reference_tensor_parallel_size: int = 1 + """The size of tensor parallelism for the reference model.""" + + # Loss + beta: float = 0.1 + """The coefficient of regularization towards the reference model.""" + + nll_scale: float = 0.0 + """The coefficient of NLL loss added to the DPO loss.""" + + length_normalization: bool = False + """Use length normalized DPO, which uses the average log probability of a sequence as the implicit reward.""" + + +@preference_unit_factory("dpo") +def create_dpo_unit( + config: DpoConfig, model: Module, root_gang: Gang, gangs: Mapping[str, Gang] +) -> DpoFinetuneUnit: + reference_model = None + if config.reference_model is not None: + reference_model = _load_reference_model( + config.reference_model, + config.reference_dtype, + root_gang, + gangs, + config.reference_tensor_parallel_size, + log, + ) + + dp_gang = gangs["dp"] # data + + return DpoFinetuneUnit( + model, + reference_model, + dp_gang, + config.beta, + config.nll_scale, + config.length_normalization, + ) diff --git a/src/fairseq2/recipes/lm/preference_finetune/orpo.py b/src/fairseq2/recipes/lm/preference_finetune/orpo.py new file mode 100644 index 000000000..81bf211fc --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/orpo.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, cast, final + +import torch +import torch.distributed +from torch import Tensor +from torch.nn import Module +from torcheval.metrics import Mean +from typing_extensions import override + +from fairseq2.datasets.preference import PreferenceOptimizationBatch +from fairseq2.gang import Gang +from fairseq2.logging import get_log_writer +from fairseq2.metrics.recorder import format_as_float, register_metric_formatter +from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input +from fairseq2.recipes.lm.preference_finetune.utils import ( + PreferenceFinetuneMetricBag, + _gather_lprobs, + preference_unit_factory, +) +from fairseq2.recipes.trainer import AbstractTrainUnit + +log = get_log_writer(__name__) + + +@final +class OrpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]): + """Represents the language model ORPO-finetuning unit. Paper: https://arxiv.org/abs/2403.07691.""" + + _lambda: float + _nll_scale: float + _metric_bag: OrpoFinetuneMetricBag + + def __init__( + self, + model: Module, + gang: Gang, + orpo_lambda: float = 1.0, + nll_scale: float = 1.0, + ) -> None: + super().__init__(model) + + self._lambda = orpo_lambda + self._nll_scale = nll_scale + + self._metric_bag = OrpoFinetuneMetricBag(gang) + + @override + def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]: + chosen_batch = batch.chosen + chosen_input_batch, chosen_target_batch = as_auto_regressive_input(chosen_batch) + rejected_batch = batch.rejected + rejected_input_batch, rejected_target_batch = as_auto_regressive_input( + rejected_batch + ) + + chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch)) + rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch)) + + chosen_logps = _gather_lprobs(chosen_output, chosen_target_batch) + rejected_logps = _gather_lprobs(rejected_output, rejected_target_batch) + + orpo_loss = self._compute_orpo_loss(chosen_logps, rejected_logps) + + nll_loss = chosen_output.compute_loss( + chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask + ) + + self._metric_bag.update_orpo_loss(batch, orpo_loss) + + self._metric_bag.update_nll_loss(chosen_batch, nll_loss) + + self._metric_bag.update_sequence_lengths(batch) + + self._metric_bag.update_logps(batch, chosen_logps, rejected_logps) + + self._metric_bag.update_batch_metrics(chosen_batch) + + loss = ( + orpo_loss + + self._nll_scale + * nll_loss + * chosen_target_batch.batch_size + / chosen_target_batch.num_target_elements() + ) # normalization applied locally per-rank + + return loss, chosen_target_batch.batch_size + + def _compute_orpo_loss( + self, + chosen_logps: Tensor, + rejected_logps: Tensor, + ) -> Tensor: + log_odds = (chosen_logps - rejected_logps) - ( + torch.log1p(-torch.exp(chosen_logps)) + - torch.log1p(-torch.exp(rejected_logps)) + ) + + orpo_loss = -torch.nn.functional.logsigmoid(log_odds) + return orpo_loss.sum() + + @override + def set_step_nr(self, step_nr: int) -> None: + """Set the current training step number.""" + self._step_nr = step_nr + + @property + @override + def metric_bag(self) -> OrpoFinetuneMetricBag: + return self._metric_bag + + +register_metric_formatter("orpo_loss", "ORPO Loss", 0, format_as_float) + + +class OrpoFinetuneMetricBag(PreferenceFinetuneMetricBag): + """Holds the metrics of a ORPO preference finetuning task.""" + + orpo_loss: Mean + + def __init__(self, gang: Gang) -> None: + super().__init__(gang) + + self.register_metric("orpo_loss", Mean(device=gang.device), persistent=False) + + @torch.inference_mode() + def update_orpo_loss( + self, batch: PreferenceOptimizationBatch, loss: Tensor + ) -> None: + """Update the ORPO loss metric. + + :param batch: + The batch processed by the model. + :param loss: + The ORPO loss of ``batch``. + """ + self.orpo_loss.update( + loss / batch.chosen.batch_size, weight=batch.chosen.batch_size + ) + + +@dataclass(kw_only=True) +class OrpoConfig: + """Holds the ORPO configuration of a language model preference-finetuning task.""" + + # Hyperparameters + orpo_lambda: float = 1.0 + """The coefficient of the odds-ratio component of ORPO loss""" + + nll_scale: float = 1.0 + """The coefficient of the NLL component of ORPO loss.""" + + +@preference_unit_factory("orpo") +def create_orpo_unit( + config: OrpoConfig, model: Module, root_gang: Gang, gangs: Mapping[str, Gang] +) -> OrpoFinetuneUnit: + dp_gang = gangs["dp"] # data + + return OrpoFinetuneUnit(model, dp_gang, config.orpo_lambda, config.nll_scale) diff --git a/src/fairseq2/recipes/lm/preference_finetune/recipe.py b/src/fairseq2/recipes/lm/preference_finetune/recipe.py new file mode 100644 index 000000000..12ad37791 --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/recipe.py @@ -0,0 +1,505 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +import torch +import torch.distributed + +from fairseq2.assets import AssetNotFoundError, default_asset_store +from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager +from fairseq2.config_registry import ConfigRegistry +from fairseq2.data.text import load_text_tokenizer +from fairseq2.datasets import Batching, LengthBatching, StaticBatching +from fairseq2.datasets.preference import ( + GenericPreferenceOptimizationDataset, + PreferenceOptimizationBatch, + PreferenceReadOptions, + load_preference_optimization_dataset, +) +from fairseq2.logging import get_log_writer +from fairseq2.models import load_model +from fairseq2.models.decoder import DecoderModel +from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing +from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa +from fairseq2.optim import AdamWConfig, create_optimizer +from fairseq2.optim.lr_scheduler import CosineAnnealingLRConfig, create_lr_scheduler +from fairseq2.recipes.lm.preference_finetune.dpo import DpoConfig +from fairseq2.recipes.lm.preference_finetune.utils import preference_unit_factories +from fairseq2.recipes.trainer import Trainer +from fairseq2.recipes.utils.asset import ( + AssetReference, + asset_as_path, + retrieve_asset_card, +) +from fairseq2.recipes.utils.log import log_model +from fairseq2.recipes.utils.setup import compile_model, setup_gangs, to_data_parallel +from fairseq2.typing import CPU, META, DataType +from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed + +log = get_log_writer(__name__) + + +@dataclass(kw_only=True) +class PreferenceFinetuneConfig: + """Holds the configuration of a language model preference-finetuning task.""" + + # Data + dataset: AssetReference = "gsm8k_dpo_data" # TODO: change! + """The name, path, or path to the asset card of the preference optimization dataset.""" + + min_seq_len: int = 1 + """The minimum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. + Shorter sequences will be dropped.""" + + max_seq_len: int = 8192 + """The maximum sum of ``src + tgt_chosen`` and ``src + tgt_rejected``. + Longer sequences will be dropped.""" + + max_num_tokens: int = 8192 * 2 + """The maximum number of total `src`, `tgt_chosen`, and `tgt_rejected` tokens per batch.""" + + batch_size: int | None = None + """If not ``None``, ignores `max_num_tokens` and each batch will have `batch_size` examples.""" + + example_shuffle_window: int = 10_000 + """The size of the sliding window for shuffling examples.""" + + batch_shuffle_window: int = 1000 + """The size of the sliding window for shuffling batches.""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + mask_source_tokens: bool = True + """If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens.""" + + src_encode_mode: str = "prompt" + """The encode mode for the prompt, determines what special tokens to add.""" + + tgt_encode_mode: str = "prompt_response" + """The encode mode for the target, determines what special tokens to add.""" + + # Model + model: AssetReference = "llama3_1_8b_instruct" + """The name or path to the asset card of the language model to finetune.""" + + model_config: Any = None + """ + The model configuration overrides. The provided values must be compatible + with the checkpoint; otherwise, the model will fail to load. + """ + + dtype: DataType = torch.bfloat16 + """The data type of the model.""" + + mixed_precision: Literal["none", "static", "dynamic"] = "static" + """ + If 'none', the whole training will be run in `dtype`. If 'static', forward + and backward passes will be run in `dtype`, but the optimizer step will be + run in full precision. If 'dynamic', forward and backward passes will be run + with `torch.amp` in `dtype`, but the optimizer step will be run in full + precision. + """ + + data_parallelism: Literal["ddp", "fsdp"] = "fsdp" + """The data parallelism API to use.""" + + fsdp_wrap_granularity: Literal["layer", "stack", "model"] = "layer" + """The granularity at which to wrap the model.""" + + fsdp_reshard_after_forward: bool = True + """If ``True``, reshards the parameters only after the backward pass.""" + + tensor_parallel_size: int = 1 + """The size of tensor parallelism.""" + + activation_checkpointing: bool = True + """If ``True``, uses layer-wise activation checkpointing.""" + + torch_compile: bool = False + """If ``True``, applies ``torch.compile()`` to the decoder. (experimental)""" + + # Criterion + criterion: str = "dpo" + """The preference optimization criterion.""" + + criterion_config: Any = field(default_factory=lambda: DpoConfig()) + """The configuration of the preference optimization criterion.""" + + # Optimizer, LR, and Loss + optimizer: str = "adamw" + """The optimizer.""" + + optimizer_config: Any = field( + default_factory=lambda: AdamWConfig( + lr=5.5e-06, betas=(0.9, 0.95), weight_decay=0.1 + ) + ) + """The configuration of the optimizer.""" + + lr_scheduler: str = "cosine-annealing" + """The learning rate scheduler.""" + + lr_scheduler_config: Any = field( + default_factory=lambda: CosineAnnealingLRConfig(final_lr_scale=0.2) + ) + """The configuration of the learning rate scheduler.""" + + gradient_accumulation: int = 1 + """The number of steps to accumulate gradients before an optimizer update.""" + + max_gradient_norm: float | None = None + """The maximum gradient norm. If ``None``, no clipping will be applied.""" + + fp16_loss_scale: tuple[float, float] = (128.0, 0.0001) + """The initial and minimum loss scale for fp16 training.""" + + # Regime + max_num_steps: int = 5000 + """The maximum number of steps to train for.""" + + max_num_data_epochs: int | None = None + """The maximum number of data epochs to train for.""" + + checkpoint_every_n_steps: int = 1000 + """The step interval at which to checkpoint.""" + + checkpoint_every_n_data_epochs: int | None = None + """The data epoch interval at which to checkpoint.""" + + keep_last_n_checkpoints: int | None = 1 + """The number of checkpoints to keep. If ``None``, none will be deleted.""" + + keep_last_n_models: int | None = None + """The number of checkpoint models to keep.""" + + publish_metrics_every_n_steps: int = 10 + """The step interval at which to publish training metrics.""" + + publish_metrics_every_n_data_epochs: int | None = None + """The data epoch interval at which to publish training metrics.""" + + # Checkpoint + resume_checkpoint_dir: Path | None = None + """If not ``None``, adds the specified path to the default asset store.""" + + # Misc + seed: int = 2 + """The random number generator seed to use.""" + + profile: tuple[int, int] | None = None + """The number of steps that the PyTorch profiler should skip and then record.""" + + monitored_gang: bool = False + """If ``True``, puts a monitored barrier before every collective call.""" + + anomaly_detection: bool = False + """If ``True``, turns on anomaly detection feature in ``torch.autograd``.""" + + wandb_project: str | None = None + """If not ``None``, sets the project name for W&B logging.""" + + wandb_run_name: str | None = None + """If not ``None``, sets the run name for W&B logging. If None, then W&B creates a random name.""" + + +preference_finetune_presets = ConfigRegistry[PreferenceFinetuneConfig]() + +preference_finetune_preset = preference_finetune_presets.decorator + + +@dataclass(kw_only=True) +class DropoutConfig: + dropout_p: float = 0.0 + + +@preference_finetune_preset("llama3_1_instruct") +def _llama3_1_instruct() -> PreferenceFinetuneConfig: + config = PreferenceFinetuneConfig() + config.model_config = DropoutConfig() + return config + + +@preference_finetune_preset("llama3_1_instruct_constant_lr") +def _llama3_1_instruct_constant_lr() -> PreferenceFinetuneConfig: + config = _llama3_1_instruct() + # setting up final lr to be the optmiizer base lr, lr_mul is 1.0 by default + config.lr_scheduler_config.final_lr = config.optimizer_config.lr + return config + + +@preference_finetune_preset("llama3_1_instruct_lr_anneal_0") +def _llama3_1_instruct_lr_anneal_0() -> PreferenceFinetuneConfig: + config = _llama3_1_instruct() + # setting up final lr to be 0.0 at the end of the cycle + config.lr_scheduler_config.final_lr = 0.0 + return config + + +@preference_finetune_preset("llama3_1_70b_instruct") +def _llama3_1_70b_instruct() -> PreferenceFinetuneConfig: + config = _llama3_1_instruct() + + config.model = "llama3_1_70b_instruct" + config.tensor_parallel_size = 8 + config.criterion_config.reference_model = "llama3_1_70b_instruct" + config.criterion_config.reference_tensor_parallel_size = 8 + + return config + + +def load_preference_finetuner( + config: PreferenceFinetuneConfig, output_dir: Path +) -> Trainer[PreferenceOptimizationBatch]: + """Load a :class:`Trainer` for language model preference optimization-finetuning.""" + wall_watch = Stopwatch(start=True) + + root_gang, gangs = setup_gangs( + log, tp_size=config.tensor_parallel_size, monitored=config.monitored_gang + ) + + dp_gang = gangs["dp"] # data + tp_gang = gangs["tp"] # tensor + + checkpoint_manager = FileCheckpointManager( + output_dir.joinpath("checkpoints"), root_gang, dp_gang=dp_gang, tp_gang=tp_gang + ) + + if config.resume_checkpoint_dir is not None: + default_asset_store.metadata_providers.append( + CheckpointModelMetadataProvider(config.resume_checkpoint_dir) + ) + + # Load the tokenizer. + model_card = retrieve_asset_card(config.model) + + log.info("Loading {} tokenizer.", model_card.name) + + tokenizer = load_text_tokenizer(model_card) + + log.info("Tokenizer loaded.") + + # Load the dataset. + try: + dataset_card = retrieve_asset_card(config.dataset) + except AssetNotFoundError: + dataset_card = None + + if dataset_card is not None: + log.info("Loading {} preference optimization dataset.", dataset_card.name) + + dataset = load_preference_optimization_dataset(dataset_card) + + log.info("Dataset loaded.") + else: + dataset_path = asset_as_path(config.dataset) + + dataset = GenericPreferenceOptimizationDataset.from_path(dataset_path) + + seed = config.seed + + # Load the model + manual_seed(seed, CPU, root_gang.device) + + seed += 1 + + init_device = META + + dtype = config.dtype if config.mixed_precision == "none" else torch.float32 + + has_checkpoint = checkpoint_manager.has_checkpoint() + + if has_checkpoint: + try: + model = load_model( + model_card, + gangs=gangs, + unstructured_config=config.model_config, + device=init_device, + dtype=dtype, + ) + except ValueError as ex: + raise ValueError( + "The model cannot be initialized. See nested exception for details." + ) from ex + # If we don't have a checkpoint, load the pretrained model on rank 0 and + # broadcast it to the gang. + else: + log.info("Loading {} model on data parallel rank 0 (per shard).", model_card.name) # fmt: skip + + if dp_gang.rank == 0: + init_device = root_gang.device + + model = load_model( + model_card, + gangs=gangs, + unstructured_config=config.model_config, + device=init_device, + dtype=dtype, + ) + + root_gang.barrier() + + log.info("Model loaded on data parallel rank 0.") + + if not isinstance(model, DecoderModel): + raise ValueError( + f"The model must be of type `{DecoderModel}`, but is of type `{type(model)}` instead." + ) + + checkpoint_manager.save_model_metadata(base_asset=model_card.name) + + mp_dtype = config.dtype if config.mixed_precision == "static" else None + + dp_model = to_data_parallel( + model, + dp_gang, + config.data_parallelism, + log, + fsdp_broadcast_state=not has_checkpoint, + fsdp_reshard_after_forward=config.fsdp_reshard_after_forward, + fsdp_mixed_precision_dtype=mp_dtype, + fsdp_fp32_reduce=True, + fsdp_wrap_granularity=config.fsdp_wrap_granularity, + ) + + if config.activation_checkpointing: + use_layerwise_activation_checkpointing(dp_model) + + if config.torch_compile: + model.decoder = compile_model(model.decoder, log) + + # TODO(balioglu): investigate! + # The memory efficient SDPA implementation in PyTorch is not stable when + # used with padded inputs. + enable_memory_efficient_torch_sdpa(dp_model, False) + + log_model(dp_model, log, rank=root_gang.rank) + + # Initialize the train unit. + try: + unit_factory = preference_unit_factories.get( + config.criterion, config.criterion_config + ) + + unit = unit_factory(dp_model, root_gang, gangs) + except ValueError as ex: + raise ValueError( + "The criterion cannot be initialized. See nested exception for details." + ) from ex + + # Initialize the data reader. + batching: Batching + + if config.batch_size is not None: + batching = StaticBatching(config.batch_size) + else: + batching = LengthBatching(config.max_num_tokens) + + options = PreferenceReadOptions( + example_shuffle_window=config.example_shuffle_window, + batch_shuffle_window=config.batch_shuffle_window, + num_accumulate=config.gradient_accumulation, + num_prefetch=config.num_prefetch, + mask_source_tokens=config.mask_source_tokens, + source_encode_mode=config.src_encode_mode, + target_encode_mode=config.tgt_encode_mode, + seed=config.seed, + ) + + try: + data_reader = dataset.create_reader( + tokenizer, + dp_gang, + config.min_seq_len, + config.max_seq_len, + batching, + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + + seed += 1 + + # Initialize the optimizer. + try: + optimizer = create_optimizer( + config.optimizer, dp_model, config.optimizer_config + ) + except ValueError as ex: + raise ValueError( + "The optimizer cannot be created. See nested exception for details." + ) from ex + + # Initialize the learning rate scheduler. + try: + lr_scheduler = create_lr_scheduler( + config.lr_scheduler, + optimizer, + config.lr_scheduler_config, + max_num_steps=config.max_num_steps, + ) + except ValueError as ex: + raise ValueError( + "The learning rate scheduler cannot be created. See nested exception for details." + ) from ex + + # TODO: Fix once we support static mixed precision on one device. + if config.mixed_precision == "static": + amp = root_gang.size == 1 or config.data_parallelism != "fsdp" + else: + amp = config.mixed_precision == "dynamic" + + if config.wandb_project is not None: + if config.wandb_run_name is None: + raise ValueError( + "`wandb_run_name` must be specified when `wandb_project` is set." + ) + + wandb_dir = output_dir.joinpath("wandb") + + wandb_options = (wandb_dir, config.wandb_project, config.wandb_run_name) + else: + wandb_options = None + + # Initialize the trainer. + return Trainer[PreferenceOptimizationBatch]( + unit=unit, + data_reader=data_reader, + root_gang=root_gang, + dp_gang=dp_gang, + tp_gang=tp_gang, + dtype=config.dtype, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + fp16_loss_scale=config.fp16_loss_scale, + max_gradient_norm=config.max_gradient_norm, + amp=amp, + max_num_steps=config.max_num_steps, + max_num_data_epochs=config.max_num_data_epochs, + checkpoint_manager=checkpoint_manager, + checkpoint_every_n_steps=config.checkpoint_every_n_steps, + checkpoint_every_n_data_epochs=config.checkpoint_every_n_data_epochs, + keep_last_n_checkpoints=config.keep_last_n_checkpoints, + keep_last_n_models=config.keep_last_n_models, + tb_dir=output_dir.joinpath("tb"), + metrics_dir=output_dir.joinpath("metrics"), + wandb_options=wandb_options, + publish_metrics_every_n_steps=config.publish_metrics_every_n_steps, + publish_metrics_every_n_data_epochs=config.publish_metrics_every_n_data_epochs, + profile=config.profile, + anomaly_detection=config.anomaly_detection, + seed=config.seed, + wall_watch=wall_watch, + ) diff --git a/src/fairseq2/recipes/lm/preference_finetune/simpo.py b/src/fairseq2/recipes/lm/preference_finetune/simpo.py new file mode 100644 index 000000000..cb305c8d8 --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/simpo.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, cast, final + +import torch +import torch.distributed +from torch import Tensor +from torch.nn import Module +from torcheval.metrics import Mean +from typing_extensions import override + +from fairseq2.datasets.preference import PreferenceOptimizationBatch +from fairseq2.gang import Gang +from fairseq2.logging import get_log_writer +from fairseq2.metrics.recorder import format_as_float, register_metric_formatter +from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input +from fairseq2.recipes.lm.preference_finetune.utils import ( + PreferenceFinetuneMetricBag, + _gather_lprobs_avg, + preference_unit_factory, +) +from fairseq2.recipes.trainer import AbstractTrainUnit + +log = get_log_writer(__name__) + + +@final +class SimPOFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]): + """Represents the language model SimPO-finetuning unit. Paper: https://arxiv.org/abs/2405.14734.""" + + _beta: float + _gamma: float + _nll_scale: float + _metric_bag: SimPOFinetuneMetricBag + + def __init__( + self, + model: Module, + gang: Gang, + beta: float = 0.1, + gamma: float = 0.5, + nll_scale: float = 1.0, + ) -> None: + super().__init__(model) + + self._beta = beta + self._gamma = gamma + self._nll_scale = nll_scale + + self._metric_bag = SimPOFinetuneMetricBag(gang) + + @override + def __call__(self, batch: PreferenceOptimizationBatch) -> tuple[Tensor, int]: + chosen_batch = batch.chosen + chosen_input_batch, chosen_target_batch = as_auto_regressive_input(chosen_batch) + rejected_batch = batch.rejected + rejected_input_batch, rejected_target_batch = as_auto_regressive_input( + rejected_batch + ) + + chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch)) + rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch)) + + chosen_logps, average_chosen_logps = _gather_lprobs_avg( + chosen_output, chosen_target_batch + ) + rejected_logps, average_rejected_logps = _gather_lprobs_avg( + rejected_output, rejected_target_batch + ) + + simpo_loss = self._compute_simpo_loss( + average_chosen_logps, average_rejected_logps + ) + + nll_loss = chosen_output.compute_loss( + chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask + ) + + self._metric_bag.update_simpo_loss(batch, simpo_loss) + + self._metric_bag.update_nll_loss(chosen_batch, nll_loss) + + self._metric_bag.update_sequence_lengths(batch) + + self._metric_bag.update_logps(batch, chosen_logps, rejected_logps) + + self._metric_bag.update_batch_metrics(chosen_batch) + + loss = ( + simpo_loss + + self._nll_scale + * nll_loss + * chosen_target_batch.batch_size + / chosen_target_batch.num_target_elements() + ) # nll normalization applied locally per-rank + + return loss, chosen_target_batch.batch_size + + def _compute_simpo_loss( + self, average_chosen_logps: Tensor, average_rejected_logps: Tensor + ) -> Tensor: + simpo_loss = -torch.nn.functional.logsigmoid( + self._beta * (average_chosen_logps - average_rejected_logps) - self._gamma + ) + return simpo_loss.sum() + + @override + def set_step_nr(self, step_nr: int) -> None: + self._step_nr = step_nr + + @property + @override + def metric_bag(self) -> SimPOFinetuneMetricBag: + return self._metric_bag + + +register_metric_formatter("simpo_loss", "SimPO Loss", 0, format_as_float) + + +class SimPOFinetuneMetricBag(PreferenceFinetuneMetricBag): + """Holds the metrics of a SimPO preference finetuning task.""" + + simpo_loss: Mean + + def __init__(self, gang: Gang) -> None: + super().__init__(gang) + + self.register_metric("simpo_loss", Mean(device=gang.device), persistent=False) + + @torch.inference_mode() + def update_simpo_loss( + self, batch: PreferenceOptimizationBatch, loss: Tensor + ) -> None: + """Update the SimPO loss metric. + + :param batch: + The batch processed by the model. + :param loss: + The SimPO loss of ``batch``. + """ + self.simpo_loss.update( + loss / batch.chosen.batch_size, weight=batch.chosen.batch_size + ) + + +@dataclass(kw_only=True) +class SimPOConfig: + """Holds the SimPO configuration of a language model preference-finetuning task.""" + + beta: float = 1 + """The coefficient of KL-divergence regularization.""" + + gamma: float = 0.5 + """The target reward margin between positive and negative completions.""" + + nll_scale: float = 0.0 + """The coefficient of NLL loss added to the SimPO loss.""" + + +@preference_unit_factory("simpo") +def create_simpo_unit( + config: SimPOConfig, model: Module, root_gang: Gang, gangs: Mapping[str, Gang] +) -> SimPOFinetuneUnit: + dp_gang = gangs["dp"] # data + + return SimPOFinetuneUnit( + model, dp_gang, config.beta, config.gamma, config.nll_scale + ) diff --git a/src/fairseq2/recipes/lm/preference_finetune/utils.py b/src/fairseq2/recipes/lm/preference_finetune/utils.py new file mode 100644 index 000000000..010d40dfc --- /dev/null +++ b/src/fairseq2/recipes/lm/preference_finetune/utils.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Mapping + +import torch +from torch import Tensor +from torch.nn import Module +from torcheval.metrics import Mean + +from fairseq2.datasets.preference import PreferenceOptimizationBatch +from fairseq2.factory_registry import ConfigBoundFactoryRegistry +from fairseq2.gang import Gang +from fairseq2.logging import LogWriter +from fairseq2.metrics.recorder import format_as_float, register_metric_formatter +from fairseq2.models import load_model +from fairseq2.models.sequence import SequenceBatch, SequenceModelOutput +from fairseq2.nn.utils.module import freeze_parameters +from fairseq2.recipes.common_metrics import SequenceMetricBag +from fairseq2.recipes.trainer import TrainUnit +from fairseq2.recipes.utils.asset import AssetReference, retrieve_asset_card +from fairseq2.recipes.utils.setup import broadcast_model +from fairseq2.typing import META, DataType + + +def _load_reference_model( + model_name_or_card: AssetReference, + dtype: DataType, + root_gang: Gang, + gangs: Mapping[str, Gang], + tensor_parallel_size: int, + log: LogWriter, +) -> Module: + dp_gang = gangs["dp"] + + card = retrieve_asset_card(model_name_or_card) + + log.info("Loading {} reference model on data parallel rank 0 (per shard).", card.name) # fmt: skip + + if dp_gang.rank == 0: + init_device = root_gang.device + else: + init_device = META + + # TODO: figure out how to load the reference model onto its own gangs + model = load_model(card, gangs=gangs, device=init_device, dtype=dtype) + + root_gang.barrier() + + log.info("Reference model loaded on data parallel rank 0.") + + model.eval() + + freeze_parameters(model) + + # Distribute the model to all processes in the gang. + if dp_gang.size != 1: + broadcast_model(model, dp_gang, log) + + return model + + +def _gather_lprobs(output: SequenceModelOutput, target: SequenceBatch) -> Tensor: + logprobs = torch.log_softmax(output.logits, dim=-1) + chosen_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1) + chosen_logps = (chosen_logps * target.target_mask).sum(dim=-1) # [Batch, 1] + + return chosen_logps + + +def _gather_lprobs_avg( + output: SequenceModelOutput, target: SequenceBatch +) -> tuple[Tensor, Tensor]: + logprobs = torch.log_softmax(output.logits, dim=-1) + per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1) + total_logps = (per_token_logps * target.target_mask).sum(dim=-1) # [Batch, 1] + assert target.target_mask is not None + average_logps = total_logps / target.target_mask.sum(-1) + + return total_logps, average_logps + + +register_metric_formatter( + "chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float +) +register_metric_formatter( + "rejected_logps", "Rejected Sequence Log Probabilities", 50, format_as_float +) +register_metric_formatter( + "chosen_lengths", "Chosen Sequence Length", 70, format_as_float +) +register_metric_formatter( + "rejected_lengths", "Rejected Sequence Length", 70, format_as_float +) + + +class PreferenceFinetuneMetricBag(SequenceMetricBag): + """Holds the metrics of a sequence model preference finetuning task.""" + + chosen_logps: Mean + rejected_logps: Mean + chosen_lengths: Mean + rejected_lengths: Mean + + def __init__(self, gang: Gang) -> None: + super().__init__(gang) + + self.register_metric("chosen_logps", Mean(device=gang.device), persistent=False) + self.register_metric( + "rejected_logps", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "chosen_lengths", Mean(device=gang.device), persistent=False + ) + self.register_metric( + "rejected_lengths", Mean(device=gang.device), persistent=False + ) + + @torch.inference_mode() + def update_logps( + self, + batch: PreferenceOptimizationBatch, + chosen_logps: Tensor, + rejected_logps: Tensor, + ) -> None: + """Update the Chosen Sequence Log Probabilities and Rejected Sequence Log Probabilities metrics. + + :param batch: + The batch processed by the model. + :param chosen_logps: + The log probabilities for each sequence in ``batch.chosen``. + :param rejected_logps: + The log probabilities for each sequence in ``batch.rejected``. + """ + self.chosen_logps.update( + chosen_logps.sum() / batch.chosen.batch_size, weight=batch.chosen.batch_size + ) + self.rejected_logps.update( + rejected_logps.sum() / batch.rejected.batch_size, + weight=batch.rejected.batch_size, + ) + + @torch.inference_mode() + def update_sequence_lengths( + self, + batch: PreferenceOptimizationBatch, + ) -> None: + """Update the Chosen Sequence Length and Rejected Sequence Length metrics. + + :param batch: + The batch processed by the model. + """ + self.chosen_lengths.update( + Tensor([batch.chosen.num_target_elements() / batch.chosen.batch_size]), + weight=batch.chosen.batch_size, + ) + self.rejected_lengths.update( + Tensor([batch.rejected.num_target_elements() / batch.rejected.batch_size]), + weight=batch.rejected.batch_size, + ) + + +preference_unit_factories = ConfigBoundFactoryRegistry[ + [Module, Gang, Mapping[str, Gang]], TrainUnit[PreferenceOptimizationBatch] +]() + +preference_unit_factory = preference_unit_factories.decorator diff --git a/src/fairseq2/recipes/lm/preference_units/dpo_unit.py b/src/fairseq2/recipes/lm/preference_units/dpo_unit.py deleted file mode 100644 index fb4964447..000000000 --- a/src/fairseq2/recipes/lm/preference_units/dpo_unit.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from pathlib import Path -from typing import Tuple, Union, cast, final - -import torch -import torch.distributed -from torch import Tensor -from torch.nn import Module -from torcheval.metrics import Mean - -from fairseq2.datasets.preference import PreferenceOptimizationBatch -from fairseq2.gang import Gang, get_rank -from fairseq2.logging import get_log_writer -from fairseq2.metrics.recorder import format_as_float, register_metric_formatter -from fairseq2.models.sequence import ( - SequenceBatch, - SequenceModelOutput, - as_auto_regressive_input, -) -from fairseq2.recipes.common_metrics import SequenceMetricBag -from fairseq2.recipes.trainer import AbstractTrainUnit -from fairseq2.typing import DataType, override - -log = get_log_writer(__name__) - - -@dataclass -class DpoFinetuneConfig: - """Holds the DPO-finetuning configuration of a language model.""" - - # Hyperparameters - dpo_beta: float = 0.1 - """The coefficient of regularization towards the reference model.""" - - nll_scale: float = 0.0 - """The coefficient of NLL loss added to the DPO loss.""" - - # Reference Model - reference_model: Union[str, Path] = "llama3_8b_instruct" - """The name or path to the asset card of the reference model to use.""" - - reference_dtype: DataType = torch.bfloat16 - """The data type of the reference model.""" - - reference_tensor_parallel_size: int = 1 - """The size of tensor parallelism for the reference model.""" - - -@final -class DpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]): - """Represents the DPO-finetuning unit of a language model.""" - - _reference_model: Module - _beta: float - _nll_scale: float - _metric_bag: DpoFinetuneMetricBag - - def __init__( - self, - model: Module, - reference_model: Module, - gang: Gang, - beta: float = 0.1, - nll_scale: float = 1.0, - ) -> None: - super().__init__(model) - - self._reference_model = reference_model - self._beta = beta - self._nll_scale = nll_scale - - self._metric_bag = DpoFinetuneMetricBag(gang) - self._gang = gang - - @override - def __call__(self, batch: PreferenceOptimizationBatch) -> Tuple[Tensor, int]: - chosen_batch = batch.chosen - chosen_input_batch, chosen_target_batch = as_auto_regressive_input(chosen_batch) - rejected_batch = batch.rejected - rejected_input_batch, rejected_target_batch = as_auto_regressive_input( - rejected_batch - ) - - chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch)) - rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch)) - - chosen_logps = self._gather_lprobs(chosen_output, chosen_target_batch) - rejected_logps = self._gather_lprobs(rejected_output, rejected_target_batch) - - with torch.no_grad(): - ref_chosen_output = cast( - SequenceModelOutput, self._reference_model(chosen_batch) - ) - ref_rejected_output = cast( - SequenceModelOutput, self._reference_model(rejected_batch) - ) - ref_chosen_logps = self._gather_lprobs( - ref_chosen_output, chosen_target_batch - ) - ref_rejected_logps = self._gather_lprobs( - ref_rejected_output, rejected_target_batch - ) - - _, _, dpo_loss = self._compute_dpo_loss( - chosen_logps, ref_chosen_logps, rejected_logps, ref_rejected_logps - ) - - nll_loss = chosen_output.compute_loss( - chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask - ) - - # adding NLL loss to the total loss for now! - loss = dpo_loss + self._nll_scale * nll_loss - - log.info( - f"Step:{self._step_nr} Rank:{get_rank()} IDs:{[str(idx) for idx in batch.chosen.example['id']]}, DPO loss: {dpo_loss.item()}" - ) - - self._metric_bag.update_nll_loss(chosen_batch, nll_loss) - self._metric_bag.update_dpo_loss(chosen_batch, dpo_loss) - - self._metric_bag.update_batch_metrics(chosen_batch) - - return loss, chosen_target_batch.batch_size - - def _gather_lprobs( - self, output: SequenceModelOutput, target: SequenceBatch - ) -> Tensor: - logprobs = torch.log_softmax(output.logits, dim=-1) - chosen_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1) - chosen_logps = (chosen_logps * target.target_mask).sum(dim=-1) # [Batch, 1] - - return chosen_logps - - def _compute_dpo_loss( - self, - chosen_logps: Tensor, - ref_chosen_logps: Tensor, - rejected_logps: Tensor, - ref_rejected_logps: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - logp_ratio_chosen = self._beta * (chosen_logps - ref_chosen_logps) - logp_ratio_rejected = self._beta * (rejected_logps - ref_rejected_logps) - dpo_loss = -torch.nn.functional.logsigmoid( - logp_ratio_chosen - logp_ratio_rejected - ) - return logp_ratio_chosen, logp_ratio_rejected, dpo_loss.sum() - - @property - @override - def metric_bag(self) -> DpoFinetuneMetricBag: - return self._metric_bag - - def set_step_nr(self, step_nr: int) -> None: - """Set the current training step number.""" - self._step_nr = step_nr - - -register_metric_formatter("dpo_loss", "DPO Loss", 0, format_as_float) - - -class DpoFinetuneMetricBag(SequenceMetricBag): - _dpo_loss: Mean - - def __init__(self, gang: Gang) -> None: - super().__init__(gang) - - self.register_metric("_dpo_loss", Mean(device=gang.device), persistent=False) - - @torch.inference_mode() - def update_dpo_loss(self, batch: SequenceBatch, loss: Tensor) -> None: - batch_size = torch.tensor(batch.batch_size) - - normalized_loss = loss.cpu() / batch_size - - self._dpo_loss.update(normalized_loss, weight=batch_size) diff --git a/src/fairseq2/recipes/lm/preference_units/simpo_unit.py b/src/fairseq2/recipes/lm/preference_units/simpo_unit.py deleted file mode 100644 index d65a656f5..000000000 --- a/src/fairseq2/recipes/lm/preference_units/simpo_unit.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Tuple, cast, final - -import torch -import torch.distributed -from torch import Tensor -from torch.nn import Module -from torcheval.metrics import Mean - -from fairseq2.datasets.preference import PreferenceOptimizationBatch -from fairseq2.gang import Gang, get_rank -from fairseq2.logging import get_log_writer -from fairseq2.metrics.recorder import format_as_float, register_metric_formatter -from fairseq2.models.sequence import ( - SequenceBatch, - SequenceModelOutput, - as_auto_regressive_input, -) -from fairseq2.recipes.common_metrics import SequenceMetricBag -from fairseq2.recipes.trainer import AbstractTrainUnit -from fairseq2.typing import override - -log = get_log_writer(__name__) - - -@dataclass -class SimpoFinetuneConfig: - """Holds the SimPO-finetuning configuration of a language model.""" - - # Hyperparameters - simpo_beta: float = 1 - """The coefficient of KL-divergence regularization.""" - - simpo_gamma: float = 0.5 - """Target reward margin between positive and negative completions.""" - - nll_scale: float = 0.0 - """The coefficient of NLL loss added to the SimPO loss.""" - - -@final -class SimpoFinetuneUnit(AbstractTrainUnit[PreferenceOptimizationBatch]): - """Represents the DPO-finetuning unit of a language model.""" - - _beta: float - _gamma: float - _nll_scale: float - _metric_bag: SimpoFinetuneMetricBag - - def __init__( - self, - model: Module, - gang: Gang, - beta: float = 0.1, - gamma: float = 0.5, - nll_scale: float = 1.0, - ) -> None: - super().__init__(model) - - self._beta = beta - self._gamma = gamma - self._nll_scale = nll_scale - - self._metric_bag = SimpoFinetuneMetricBag(gang) - self._gang = gang - - @override - def __call__(self, batch: PreferenceOptimizationBatch) -> Tuple[Tensor, int]: - chosen_batch = batch.chosen - chosen_input_batch, chosen_target_batch = as_auto_regressive_input(chosen_batch) - rejected_batch = batch.rejected - rejected_input_batch, rejected_target_batch = as_auto_regressive_input( - rejected_batch - ) - - chosen_output = cast(SequenceModelOutput, self._model(chosen_input_batch)) - rejected_output = cast(SequenceModelOutput, self._model(rejected_input_batch)) - - chosen_logps, average_chosen_logps = self._gather_lprobs( - chosen_output, chosen_target_batch - ) - rejected_logps, average_rejected_logps = self._gather_lprobs( - rejected_output, rejected_target_batch - ) - - simpo_loss = self._compute_simpo_loss( - average_chosen_logps, average_rejected_logps - ) - - nll_loss = chosen_output.compute_loss( - chosen_target_batch.seqs, loss_mask=chosen_target_batch.target_mask - ) - - # adding NLL loss to the total loss for now! - loss = simpo_loss + self._nll_scale * nll_loss - - log.info( - f"Step:{self._step_nr} Rank:{get_rank()} IDs:{[str(idx) for idx in batch.chosen.example['id']]}, SimPO loss: {simpo_loss.item()}" - ) - - self._metric_bag.update_nll_loss(chosen_batch, nll_loss) - self._metric_bag.update_simpo_loss(chosen_batch, simpo_loss) - - self._metric_bag.update_batch_metrics(chosen_batch) - - return loss, chosen_target_batch.batch_size - - def _gather_lprobs( - self, output: SequenceModelOutput, target: SequenceBatch - ) -> Tuple[Tensor, Tensor]: - logprobs = torch.log_softmax(output.logits, dim=-1) - per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze( - -1 - ) - total_logps = (per_token_logps * target.target_mask).sum(dim=-1) # [Batch, 1] - assert ( - target.target_mask is not None - ) # TODO hacky mypy fix - perhaps use the length of the per_token_logps? - average_logps = total_logps / target.target_mask.sum(-1) - - return total_logps, average_logps - - def _compute_simpo_loss( - self, - average_chosen_logps: Tensor, - average_rejected_logps: Tensor, - ) -> Tensor: - simpo_loss = -torch.nn.functional.logsigmoid( - self._beta * (average_chosen_logps - average_rejected_logps) - self._gamma - ) - return simpo_loss.sum() - - @property - @override - def metric_bag(self) -> SimpoFinetuneMetricBag: - return self._metric_bag - - def set_step_nr(self, step_nr: int) -> None: - """Set the current training step number.""" - self._step_nr = step_nr - - -register_metric_formatter("simpo_loss", "SimPO Loss", 0, format_as_float) - - -class SimpoFinetuneMetricBag(SequenceMetricBag): - _simpo_loss: Mean - - def __init__(self, gang: Gang) -> None: - super().__init__(gang) - - self.register_metric("_simpo_loss", Mean(device=gang.device), persistent=False) - - @torch.inference_mode() - def update_simpo_loss(self, batch: SequenceBatch, loss: Tensor) -> None: - batch_size = torch.tensor(batch.batch_size) - - normalized_loss = loss.cpu() / batch_size - - self._simpo_loss.update(normalized_loss, weight=batch_size) diff --git a/src/fairseq2/recipes/lm/text_generate.py b/src/fairseq2/recipes/lm/text_generate.py index 3c7699bfb..c020281bb 100644 --- a/src/fairseq2/recipes/lm/text_generate.py +++ b/src/fairseq2/recipes/lm/text_generate.py @@ -9,9 +9,10 @@ import json from dataclasses import dataclass, field from pathlib import Path -from typing import Literal, Optional, TextIO, final +from typing import TextIO, final import torch +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider @@ -20,18 +21,11 @@ from fairseq2.datasets import StaticBatching from fairseq2.datasets.instruction import ( GenericInstructionDataset, + InstructionPromptReadOptions, load_instruction_dataset, ) from fairseq2.gang import Gang -from fairseq2.generation import ( - BeamSearchSequenceGenerator, - Sampler, - SamplingSequenceGenerator, - SequenceGenerator, - StandardBeamSearchAlgorithm, - TopKSampler, - TopPSampler, -) +from fairseq2.generation import SamplingConfig, SequenceGenerator, create_seq_generator from fairseq2.logging import get_log_writer from fairseq2.models import load_model from fairseq2.models.decoder import DecoderModel @@ -44,14 +38,15 @@ retrieve_asset_card, ) from fairseq2.recipes.utils.log import log_model -from fairseq2.recipes.utils.setup import broadcast_model, check_model_type, setup_gangs -from fairseq2.typing import META, DataType, override +from fairseq2.recipes.utils.setup import broadcast_model, setup_gangs +from fairseq2.typing import CPU, META, DataClass, DataType from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class TextGenerateConfig: """Holds the configuration of a text generation task.""" @@ -59,6 +54,12 @@ class TextGenerateConfig: dataset: AssetReference = "foo" # TODO: change! """The name, path, or path to the asset card of the instruction dataset.""" + split: str = "default" + """The name of the data split.""" + + min_seq_len: int = 1 + """The minimum sequence length.""" + max_seq_len: int = 8192 """The maximum sequence length.""" @@ -72,124 +73,30 @@ class TextGenerateConfig: model: AssetReference = "llama3_8b_instruct" """The name of the model to generate with.""" - checkpoint_dir: Optional[Path] = None + checkpoint_dir: Path | None = None """The checkpoint directory containing models saved by :class:`FileCheckpointManager`.""" dtype: DataType = torch.bfloat16 """The data type of the model.""" + amp: bool = False + """If ``True``, runs evaluation with ``torch.amp``.""" + tensor_parallel_size: int = 1 """The size of tensor parallelism.""" # Generation - mode: Literal["sampling", "beam_search"] = "sampling" - """The mode of sequence generation.""" - - sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig()) - """The configuration for sampling-based sequence generation.""" + generator: str = "sampling" + """The sequence generator.""" - beam_search: BeamSearchConfig = field(default_factory=lambda: BeamSearchConfig()) - """The configuration for beam search-based sequence generation.""" + generator_config: DataClass | None = field(default_factory=lambda: SamplingConfig()) + """The configuration of the sequence generator.""" # Misc seed: int = 2 """The random number generator seed to use.""" -@dataclass -class SamplingConfig: - """Holds the configuration for sampling-based sequence generation. - - See :class:`SamplingSequenceGenerator` for more info. - """ - - sampler: Literal["top-p", "top-k"] = "top-p" - """The sampling algorithm.""" - - top_p: float = 1.0 - """The cumulative probability threshold for top-p sampling.""" - - top_k = 1 - """The number of top candidates to select from for top-k sampling.""" - - min_gen_len: int = 1 - """The minimum generation length.""" - - max_gen_len: int = 2048 - """The maximum generation length.""" - - max_seq_len: Optional[int] = None - """The maximum sequence length including prompt.""" - - echo_prompt: bool = False - """If ``True``, returns generated sequences with prompts appended.""" - - compute_scores: bool = False - """If ``True``, computes scores of generated sequences.""" - - normalize_scores: bool = True - """If ``True``, normalizes scores by lengths of generated sequences.""" - - temperature: float = 1.0 - """The logit temperature.""" - - unk_penalty: float = 0.0 - """The UNK symbol penalty.""" - - len_penalty: float = 1.0 - """The length penalty.""" - - prefill_chunk_size: Optional[int] = 512 - """The prefill will be performed incrementally by chunks of this size.""" - - decode_capacity_increment: Optional[int] = 16 - """The sequence length capacity will be incremented by multiplies of this value.""" - - -@dataclass -class BeamSearchConfig: - """Holds the configuration for beam search-based sequence generation. - - See :class:`BeamSearchSequenceGenerator` for more info. - """ - - algorithm: Literal["standard"] = "standard" - """The beam search algorithm.""" - - beam_size: int = 5 - """The beam size.""" - - min_gen_len: int = 1 - """The minimum generation length.""" - - max_gen_len: int = 2048 - """The maximum generation length.""" - - max_seq_len: Optional[int] = None - """The maximum sequence length including prompt.""" - - echo_prompt: bool = False - """If ``True``, returns generated sequences with prompts appended.""" - - normalize_scores: bool = True - """If ``True``, normalizes scores by lengths of generated sequences.""" - - temperature: float = 1.0 - """The logit temperature.""" - - unk_penalty: float = 0.0 - """The UNK symbol penalty.""" - - len_penalty: float = 1.0 - """The length penalty.""" - - prefill_chunk_size: Optional[int] = 512 - """The prefill will be performed incrementally by chunks of this size.""" - - decode_capacity_increment: Optional[int] = 16 - """The sequence length capacity will be incremented by multiplies of this value.""" - - text_generate_presets = ConfigRegistry[TextGenerateConfig]() text_generate_preset = text_generate_presets.decorator @@ -229,6 +136,25 @@ def _llama3_70b_instruct() -> TextGenerateConfig: return config +@text_generate_preset("llama3_1_8b_instruct") +def _llama3_1_8b_instruct() -> TextGenerateConfig: + config = _llama3_8b_instruct() + + config.model = "llama3_1_8b_instruct" + + return config + + +@text_generate_preset("llama3_1_70b_instruct") +def _llama3_1_70b_instruct() -> TextGenerateConfig: + config = _llama3_70b_instruct() + + config.model = "llama3_1_70b_instruct" + + return config + + +@torch.inference_mode() def load_text_generator( config: TextGenerateConfig, output_dir: Path ) -> Generator[SequenceBatch]: @@ -245,8 +171,6 @@ def load_text_generator( dp_gang = gangs["dp"] # data tp_gang = gangs["tp"] # tensor - seed = config.seed - model_card = retrieve_asset_card(config.model) # Load the tokenizer. @@ -273,7 +197,13 @@ def load_text_generator( dataset = GenericInstructionDataset.from_path(dataset_path) - # Load the model. + seed = config.seed + + # Load the model + manual_seed(seed, CPU, root_gang.device) + + seed += 1 + log.info("Loading {} model on data parallel rank 0 (per shard).", model_card.name) if dp_gang.rank == 0: @@ -290,7 +220,10 @@ def load_text_generator( "The model cannot be initialized. See nested exception for details." ) from ex - check_model_type(model, DecoderModel) + if not isinstance(model, DecoderModel): + raise ValueError( + f"The model must be of type `{DecoderModel}`, but is of type `{type(model)}` instead." + ) root_gang.barrier() @@ -303,9 +236,14 @@ def load_text_generator( log_model(model, log) # Initialize the sequence generator. - generator = _create_sequence_generator( - model, config.mode, config.beam_search, config.sampling # type: ignore[arg-type] - ) + try: + generator = create_seq_generator( + config.generator, model, config.generator_config + ) + except ValueError as ex: + raise ValueError( + "The sequence generator cannot be created. See nested exception for details." + ) from ex # Initialize the generator unit. if tp_gang.rank == 0: @@ -345,16 +283,25 @@ def load_text_generator( json_output_stream=json_output_fp, ) - data_reader = dataset.create_prompt_reader( - tokenizer, - dp_gang, - config.max_seq_len, - batching=StaticBatching(config.batch_size), - sync_batches=False, - num_prefetch=config.num_prefetch, - seed=seed, + options = InstructionPromptReadOptions( + sync_mode="until_last", num_prefetch=config.num_prefetch ) + try: + data_reader = dataset.create_prompt_reader( + config.split, + tokenizer, + dp_gang, + config.min_seq_len, + config.max_seq_len, + StaticBatching(config.batch_size), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + seed += 1 # Initialize the generator. @@ -364,6 +311,8 @@ def load_text_generator( root_gang=root_gang, dp_gang=dp_gang, tp_gang=tp_gang, + dtype=config.dtype, + amp=config.amp, metrics_dir=output_dir.joinpath("metrics"), seed=seed, wall_watch=wall_watch, @@ -376,8 +325,8 @@ class TextGenerateUnit(AbstractGeneratorUnit[SequenceBatch]): _generator: SequenceGenerator _text_decoder: TextTokenDecoder - _text_output_stream: Optional[TextIO] - _json_output_stream: Optional[TextIO] + _text_output_stream: TextIO | None + _json_output_stream: TextIO | None _metric_bag: SequenceGenerationMetricBag def __init__( @@ -385,8 +334,8 @@ def __init__( generator: SequenceGenerator, tokenizer: TextTokenizer, gang: Gang, - text_output_stream: Optional[TextIO], - json_output_stream: Optional[TextIO], + text_output_stream: TextIO | None, + json_output_stream: TextIO | None, ) -> None: super().__init__(generator.model) @@ -504,77 +453,3 @@ def __call__(self, batch: SequenceBatch) -> None: @override def metric_bag(self) -> SequenceGenerationMetricBag: return self._metric_bag - - -def _create_sequence_generator( - model: DecoderModel, - mode: str, - beam_search_config: BeamSearchConfig, - sampling_config: SamplingConfig, -) -> SequenceGenerator: - if mode == "sampling": - return _create_sampling_generator(model, sampling_config) - - if mode == "beam_search": - return _create_beam_search_generator(model, beam_search_config) - - raise ValueError( - f"`generator_mode` must be 'sampling' or 'beam_search', but is '{mode}' instead." - ) - - -def _create_sampling_generator( - model: DecoderModel, config: SamplingConfig -) -> SamplingSequenceGenerator: - sampler: Sampler - - if config.sampler == "top-p": - sampler = TopPSampler(config.top_p) - elif config.sampler == "top-k": - sampler = TopKSampler(config.top_k) - else: - raise ValueError( - f"`sampling.sampler` must be 'top-p' or 'top-k', but is '{config.sampler}' instead." - ) - - return SamplingSequenceGenerator( - model, - sampler, - min_gen_len=config.min_gen_len, - max_gen_len=config.max_gen_len, - max_seq_len=config.max_seq_len, - echo_prompt=config.echo_prompt, - compute_scores=config.compute_scores, - normalize_scores=config.normalize_scores, - temperature=config.temperature, - unk_penalty=config.unk_penalty, - len_penalty=config.len_penalty, - prefill_chunk_size=config.prefill_chunk_size, - decode_capacity_increment=config.decode_capacity_increment, - ) - - -def _create_beam_search_generator( - model: DecoderModel, config: BeamSearchConfig -) -> BeamSearchSequenceGenerator: - if config.algorithm == "standard": - algorithm = StandardBeamSearchAlgorithm() - else: - raise ValueError( - f"`beam_search.algorithm` must be 'standard', but is '{config.algorithm}' instead." - ) - - return BeamSearchSequenceGenerator( - model, - algorithm=algorithm, - beam_size=config.beam_size, - min_gen_len=config.min_gen_len, - max_gen_len=config.max_gen_len, - echo_prompt=config.echo_prompt, - normalize_scores=config.normalize_scores, - temperature=config.temperature, - unk_penalty=config.unk_penalty, - len_penalty=config.len_penalty, - prefill_chunk_size=config.prefill_chunk_size, - decode_capacity_increment=config.decode_capacity_increment, - ) diff --git a/src/fairseq2/recipes/logging.py b/src/fairseq2/recipes/logging.py index 168e47b5b..60faf3301 100644 --- a/src/fairseq2/recipes/logging.py +++ b/src/fairseq2/recipes/logging.py @@ -9,123 +9,108 @@ import logging import os import time +from abc import ABC, abstractmethod from logging import DEBUG, INFO, FileHandler, Formatter, Handler, NullHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, final from fairseq2n import DOC_MODE from rich.logging import RichHandler +from typing_extensions import override -from fairseq2.console import get_error_console +from fairseq2.error import SetupError +from fairseq2.gang import get_rank +from fairseq2.recipes.utils.rich import get_error_console def setup_basic_logging(*, debug: bool = False, utc_time: bool = False) -> None: - """Set up logging for a command line program. - - :param debug: - If ``True``, sets the log level to ``DEBUG``; otherwise, to ``INFO``. - :param utc_time: - If ``True``, logs dates and times in UTC. - """ - from fairseq2.gang import get_rank # Avoid circular import. - rank = get_rank() - _do_setup_logging(rank, debug, utc_time) + _setup_core_logging(rank, debug, utc_time) if rank != 0: getLogger().addHandler(NullHandler()) -def setup_logging( - log_file: Path, *, debug: bool = False, utc_time: bool = False, force: bool = False -) -> None: - """Set up logging for a distributed job. +class LoggingInitializer(ABC): + @abstractmethod + def initialize( + self, log_file: Path, *, debug: bool = False, utc_time: bool = False + ) -> None: + ... - :param log_file: - The file to which logs will be written. Must have a 'rank' replacement - field; for example '/path/to/train_{rank}.log'. - :param debug: - If ``True``, sets the log level to ``DEBUG``; otherwise, to ``INFO``. - :param utc_time: - If ``True``, logs dates and times in UTC. - :param force: - If ``True``, overwrites existing ATen and NCCL log configurations. - """ - from fairseq2.gang import get_rank # Avoid circular import. - rank = get_rank() +@final +class DistributedLoggingInitializer(LoggingInitializer): + @override + def initialize( + self, log_file: Path, *, debug: bool = False, utc_time: bool = False + ) -> None: + rank = get_rank() - filename = log_file.name.format(rank=rank) + filename = log_file.name.format(rank=rank) - if filename == log_file.name: - raise ValueError( - f"`log_file` must contain a 'rank' replacement field (i.e. {{rank}}) in its filename, but is '{log_file}' instead." - ) + if filename == log_file.name: + raise ValueError( + f"`log_file` must have a 'rank' replacement field (i.e. {{rank}}) in its filename, but is '{log_file}' instead." + ) - log_file = log_file.with_name(filename) + log_file = log_file.with_name(filename) - try: - log_file.parent.mkdir(parents=True, exist_ok=True) - except OSError as ex: - raise RuntimeError( - f"The log directory ({log_file.parent}) cannot be created. See nested exception for details." - ) from ex + try: + log_file.parent.mkdir(parents=True, exist_ok=True) + except OSError as ex: + raise SetupError( + f"The '{log_file}' log file cannot be created. See the nested exception for details." + ) from ex - _do_setup_logging(rank, debug, utc_time) + _setup_core_logging(rank, debug, utc_time) - handler = FileHandler(log_file) + handler = FileHandler(log_file) - fmt = Formatter(f"[Rank {rank}] %(asctime)s %(levelname)s %(name)s - %(message)s") + handler.setFormatter( + Formatter(f"[Rank {rank}] %(asctime)s %(levelname)s %(name)s - %(message)s") + ) - handler.setFormatter(fmt) + getLogger().addHandler(handler) - getLogger().addHandler(handler) + _setup_aten_logging(log_file) + _setup_nccl_logging(log_file) - _setup_aten_logging(log_file, force) - _setup_nccl_logging(log_file, force) +def _setup_core_logging(rank: int, debug: bool = False, utc_time: bool = False) -> None: + level = DEBUG if debug else INFO -def _do_setup_logging(rank: int, debug: bool = False, utc_time: bool = False) -> None: if utc_time: Formatter.converter = time.gmtime - handlers: List[Handler] = [] + handlers: list[Handler] = [] if rank == 0: console = get_error_console() handler = RichHandler(console=console, show_path=False, keywords=[]) - fmt = Formatter("%(name)s - %(message)s") - - handler.setFormatter(fmt) + handler.setFormatter(Formatter("%(name)s - %(message)s")) handlers.append(handler) datefmt = "%Y-%m-%d %H:%M:%S" - logging.basicConfig( - level=DEBUG if debug else INFO, handlers=handlers, datefmt=datefmt, force=True - ) + logging.basicConfig(level=level, handlers=handlers, datefmt=datefmt, force=True) -def _setup_aten_logging(log_file: Path, force: bool) -> None: - if "TORCH_CPP_LOG_LEVEL" in os.environ and not force: +def _setup_aten_logging(log_file: Path) -> None: + if "TORCH_CPP_LOG_LEVEL" in os.environ: return aten_log_file = log_file.parent.joinpath("aten", log_file.name) - try: - aten_log_file.parent.mkdir(parents=True, exist_ok=True) - except OSError as ex: - raise RuntimeError( - f"The ATen log directory ({aten_log_file.parent}) cannot be created. See nested exception for details." - ) from ex + aten_log_file.parent.mkdir(parents=True, exist_ok=True) _enable_aten_logging(aten_log_file) - # This variable has no effect at this point. We set it for completeness. + # This variable has no effect at this point; set for completeness. os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO" @@ -138,18 +123,13 @@ def _enable_aten_logging(log_file: Path) -> Path: from fairseq2n.bindings import _enable_aten_logging -def _setup_nccl_logging(log_file: Path, force: bool) -> None: - if "NCCL_DEBUG" in os.environ and not force: +def _setup_nccl_logging(log_file: Path) -> None: + if "NCCL_DEBUG" in os.environ: return nccl_log_file = log_file.parent.joinpath("nccl", log_file.name) - try: - nccl_log_file.parent.mkdir(parents=True, exist_ok=True) - except OSError as ex: - raise RuntimeError( - f"The NCCL log directory ({nccl_log_file.parent}) cannot be created. See nested exception for details." - ) from ex + nccl_log_file.parent.mkdir(parents=True, exist_ok=True) os.environ["NCCL_DEBUG"] = "INFO" os.environ["NCCL_DEBUG_FILE"] = str(nccl_log_file) diff --git a/src/fairseq2/recipes/mt/__init__.py b/src/fairseq2/recipes/mt/__init__.py index abbe74fd4..a29f28b72 100644 --- a/src/fairseq2/recipes/mt/__init__.py +++ b/src/fairseq2/recipes/mt/__init__.py @@ -10,11 +10,10 @@ from fairseq2.recipes.mt.eval import load_mt_evaluator, mt_eval_presets from fairseq2.recipes.mt.train import load_mt_trainer, mt_train_presets from fairseq2.recipes.mt.translate import load_text_translator, text_translate_presets -from fairseq2.recipes.utils.sweep import default_sweep_tagger def _setup_mt_cli(cli: Cli) -> None: - default_sweep_tagger.extend_allow_set("source_lang", "target_lang") + extra_sweep_keys = {"source_lang", "target_lang"} group = cli.add_group("mt", help="machine translation recipes") @@ -23,6 +22,7 @@ def _setup_mt_cli(cli: Cli) -> None: loader=load_mt_trainer, preset_configs=mt_train_presets, default_preset="nllb_dense_600m", + extra_sweep_keys=extra_sweep_keys, ) group.add_command( @@ -36,6 +36,7 @@ def _setup_mt_cli(cli: Cli) -> None: loader=load_mt_evaluator, preset_configs=mt_eval_presets, default_preset="nllb_dense_600m", + extra_sweep_keys=extra_sweep_keys, ) group.add_command( @@ -49,6 +50,7 @@ def _setup_mt_cli(cli: Cli) -> None: loader=load_text_translator, preset_configs=text_translate_presets, default_preset="nllb_dense_600m", + extra_sweep_keys=extra_sweep_keys, ) group.add_command( diff --git a/src/fairseq2/recipes/mt/common.py b/src/fairseq2/recipes/mt/common.py new file mode 100644 index 000000000..df00d9459 --- /dev/null +++ b/src/fairseq2/recipes/mt/common.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import final + +from torch import Tensor +from torch.nn import Module + +from fairseq2.models.encoder_decoder import EncoderDecoderModel +from fairseq2.models.seq2seq import Seq2SeqBatch, as_auto_regressive_input +from fairseq2.models.sequence import SequenceModelOutput +from fairseq2.recipes.common_metrics import Seq2SeqMetricBag +from fairseq2.recipes.utils.setup import check_model_type + + +@final +class MTCriterion: + _model: Module + _label_smoothing: float + + def __init__(self, model: Module, *, label_smoothing: float = 0.0) -> None: + check_model_type(model, EncoderDecoderModel) + + self._model = model + + self._label_smoothing = label_smoothing + + def __call__( + self, batch: Seq2SeqBatch, metric_bag: Seq2SeqMetricBag + ) -> tuple[Tensor, int]: + input_batch, target_batch = as_auto_regressive_input(batch) + + output = self._forward(input_batch) + + loss = output.compute_loss( + target_batch.seqs, label_smoothing=self._label_smoothing + ) + + metric_bag.update_nll_loss(input_batch, loss) + + metric_bag.update_batch_metrics(input_batch) + + return loss, batch.num_target_elements() + + def _forward(self, batch: Seq2SeqBatch) -> SequenceModelOutput: + return self._model(batch) # type: ignore[no-any-return] + + @property + def model(self) -> Module: + return self._model diff --git a/src/fairseq2/recipes/mt/eval.py b/src/fairseq2/recipes/mt/eval.py index 7094ad4a1..f42f2ea50 100644 --- a/src/fairseq2/recipes/mt/eval.py +++ b/src/fairseq2/recipes/mt/eval.py @@ -8,10 +8,10 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import List, Literal, Optional, TextIO, final +from typing import Any, TextIO, final import torch -from torch.nn import Module +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider @@ -21,42 +21,38 @@ from fairseq2.datasets.parallel_text import ( Direction, GenericParallelTextDataset, + ParallelTextReadOptions, load_parallel_text_dataset, ) from fairseq2.gang import Gang -from fairseq2.generation import Seq2SeqGenerator -from fairseq2.generation.text import SequenceToTextConverter +from fairseq2.generation import ( + BeamSearchConfig, + Seq2SeqGenerator, + SequenceToTextConverter, + create_seq2seq_generator, +) from fairseq2.logging import get_log_writer from fairseq2.metrics.text import BleuMetric, ChrfMetric from fairseq2.models import load_model from fairseq2.models.encoder_decoder import EncoderDecoderModel -from fairseq2.models.seq2seq import Seq2SeqBatch, as_auto_regressive_input -from fairseq2.models.sequence import SequenceModelOutput +from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.recipes.common_metrics import Seq2SeqGenerationMetricBag, Seq2SeqMetricBag from fairseq2.recipes.evaluator import AbstractEvalUnit, Evaluator, EvalUnit -from fairseq2.recipes.mt.translate import ( - BeamSearchConfig, - SamplingConfig, - _create_sequence_generator, -) +from fairseq2.recipes.mt.common import MTCriterion from fairseq2.recipes.utils.asset import ( AssetReference, asset_as_path, retrieve_asset_card, ) from fairseq2.recipes.utils.log import log_model -from fairseq2.recipes.utils.setup import ( - broadcast_model, - check_model_type, - setup_root_gang, -) -from fairseq2.typing import META, DataType, override +from fairseq2.recipes.utils.setup import broadcast_model, setup_root_gang +from fairseq2.typing import META, DataType from fairseq2.utils.profiler import Stopwatch log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class MTEvalConfig: """Holds the configuration of a machine translation evaluation task.""" @@ -67,6 +63,9 @@ class MTEvalConfig: split: str = "test" """The name of the test data split.""" + min_seq_len: int = 1 + """The maximum sequence length.""" + max_seq_len: int = 512 """The maximum sequence length.""" @@ -80,28 +79,30 @@ class MTEvalConfig: model: AssetReference = "nllb-200_dense_distill_600m" """The name of the model to evaluate.""" - checkpoint_dir: Optional[Path] = None + checkpoint_dir: Path | None = None """The checkpoint directory containing models saved by :class:`FileCheckpointManager`.""" dtype: DataType = torch.float16 """The data type of the model.""" + amp: bool = False + """If ``True``, runs evaluation with ``torch.amp``.""" + # Loss label_smoothing: float = 0.1 """The amount of label smoothing to apply while computing the loss.""" # BLEU/chrF++ - generator_mode: Literal["beam_search", "sampling"] = "beam_search" - """The mode of sequence generation.""" - - beam_search: BeamSearchConfig = field(default_factory=lambda: BeamSearchConfig()) - """The configuration for beam search-based sequence generation.""" + generator: str = "beam_search" + """The sequence generator.""" - sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig()) - """The configuration for sampling-based sequence generation.""" + generator_config: Any = field( + default_factory=lambda: BeamSearchConfig(max_gen_len=(1, 256), echo_prompt=True) + ) + """The configuration of the sequence generator.""" generator_batch_size: int = 8 - """The number of sentences per generator batch.""" + """The number of sentences per batch.""" # Misc seed: int = 2 @@ -118,6 +119,7 @@ def _nllb_dense_600m() -> MTEvalConfig: return MTEvalConfig() +@torch.inference_mode() def load_mt_evaluator( config: MTEvalConfig, output_dir: Path ) -> Evaluator[Seq2SeqBatch]: @@ -131,8 +133,6 @@ def load_mt_evaluator( gang = setup_root_gang(log) - seed = config.seed - model_card = retrieve_asset_card(config.model) # Load the tokenizer. @@ -174,7 +174,10 @@ def load_mt_evaluator( "The model cannot be initialized. See nested exception for details." ) from ex - check_model_type(model, EncoderDecoderModel) + if not isinstance(model, EncoderDecoderModel): + raise ValueError( + f"The model must be of type `{EncoderDecoderModel}`, but is of type `{type(model)}` instead." + ) gang.barrier() @@ -187,38 +190,53 @@ def load_mt_evaluator( log_model(model, log) # Initialize the sequence generator. - generator = _create_sequence_generator( - model, config.generator_mode, config.beam_search, config.sampling # type: ignore[arg-type] - ) + try: + generator = create_seq2seq_generator( + config.generator, model, config.generator_config + ) + except ValueError as ex: + raise ValueError( + "The sequence generator cannot be created. See nested exception for details." + ) from ex + + # Initialize the criterion. + criterion = MTCriterion(model, label_smoothing=config.label_smoothing) # Initialize the evaluation units. - units: List[EvalUnit[Seq2SeqBatch]] = [] + units: list[EvalUnit[Seq2SeqBatch]] = [] + + seed = config.seed data_readers = [] for direction in dataset.directions(config.split): # Loss Evaluation - loss_unit = MTLossEvalUnit( - model, - direction, - gang, - label_smoothing=config.label_smoothing, - ) + loss_unit = MTLossEvalUnit(criterion, direction, gang) units.append(loss_unit) - data_reader = dataset.create_reader( - config.split, - tokenizer, - gang, - config.max_seq_len, - batching=LengthBatching(config.max_num_tokens), + options = ParallelTextReadOptions( direction=direction, - sync_batches=False, + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed, ) + try: + data_reader = dataset.create_reader( + config.split, + tokenizer, + gang, + config.min_seq_len, + config.max_seq_len, + LengthBatching(config.max_num_tokens), + options, + ) + except ValueError as ex: + raise ValueError( + f"The data reader for '{direction}' cannot be initialized. See nested exception for details." + ) from ex + seed += 1 data_readers.append(data_reader) @@ -276,18 +294,28 @@ def load_mt_evaluator( units.append(score_unit) - data_reader = dataset.create_reader( - config.split, - tokenizer, - gang, - config.max_seq_len, - batching=StaticBatching(config.generator_batch_size), + options = ParallelTextReadOptions( direction=direction, - sync_batches=False, + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed, ) + try: + data_reader = dataset.create_reader( + config.split, + tokenizer, + gang, + config.min_seq_len, + config.max_seq_len, + StaticBatching(config.generator_batch_size), + options, + ) + except ValueError as ex: + raise ValueError( + f"The data reader for '{direction}' cannot be initialized. See nested exception for details." + ) from ex + seed += 1 data_readers.append(data_reader) @@ -297,6 +325,8 @@ def load_mt_evaluator( units=units, data_readers=data_readers, root_gang=gang, + dtype=config.dtype, + amp=config.amp, tb_dir=output_dir.joinpath("tb"), metrics_dir=output_dir.joinpath("metrics"), seed=seed, @@ -306,53 +336,21 @@ def load_mt_evaluator( @final class MTLossEvalUnit(AbstractEvalUnit[Seq2SeqBatch]): - """Represents a machine translation loss evaluation unit.""" - - _label_smoothing: float + _criterion: MTCriterion _metric_bag: Seq2SeqMetricBag def __init__( - self, - model: Module, - direction: Direction, - gang: Gang, - *, - label_smoothing: float = 0.0, + self, criterion: MTCriterion, direction: Direction, gang: Gang ) -> None: - """ - :param model: - The encoder-decoder model. Might be wrapped with DDP or FSDP. - :param direction: - The language direction to evaluate. - :param gang: - The gang for distributed evaluation. - :param label_smoothing: - The amount of label smoothing to apply while computing the loss. - """ - super().__init__(model, display_name=f"loss/{direction}") + super().__init__(criterion.model, display_name=f"loss/{direction}") - check_model_type(model, EncoderDecoderModel) - - self._label_smoothing = label_smoothing + self._criterion = criterion self._metric_bag = Seq2SeqMetricBag(gang, train=False) @override def __call__(self, batch: Seq2SeqBatch) -> None: - input_batch, target_batch = as_auto_regressive_input(batch) - - output = self._forward(input_batch) - - loss = output.compute_loss( - target_batch.seqs, label_smoothing=self._label_smoothing - ) - - self._metric_bag.update_nll_loss(input_batch, loss.detach()) - - self._metric_bag.update_batch_metrics(input_batch) - - def _forward(self, batch: Seq2SeqBatch) -> SequenceModelOutput: - return self._model(batch) # type: ignore[no-any-return] + self._criterion(batch, self._metric_bag) @property @override @@ -365,9 +363,9 @@ class MTBleuChrfEvalUnit(AbstractEvalUnit[Seq2SeqBatch]): """Represents a machine translation BLEU/chrF++ evaluation unit.""" _converter: SequenceToTextConverter - _src_output_stream: Optional[TextIO] - _ref_output_stream: Optional[TextIO] - _hyp_output_stream: Optional[TextIO] + _src_output_stream: TextIO | None + _ref_output_stream: TextIO | None + _hyp_output_stream: TextIO | None _metric_bag: Seq2SeqGenerationMetricBag def __init__( @@ -377,9 +375,9 @@ def __init__( tokenizer: TextTokenizer, gang: Gang, *, - src_output_stream: Optional[TextIO] = None, - ref_output_stream: Optional[TextIO] = None, - hyp_output_stream: Optional[TextIO] = None, + src_output_stream: TextIO | None = None, + ref_output_stream: TextIO | None = None, + hyp_output_stream: TextIO | None = None, ) -> None: """ :param direction: diff --git a/src/fairseq2/recipes/mt/train.py b/src/fairseq2/recipes/mt/train.py index c5f92aff4..5b71cc60c 100644 --- a/src/fairseq2/recipes/mt/train.py +++ b/src/fairseq2/recipes/mt/train.py @@ -8,11 +8,11 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Literal, Optional, Tuple, final +from typing import Any, Literal, final import torch from torch import Tensor -from torch.nn import Module +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager @@ -21,24 +21,21 @@ from fairseq2.datasets import LengthBatching, StaticBatching from fairseq2.datasets.parallel_text import ( GenericParallelTextDataset, + ParallelTextReadOptions, load_parallel_text_dataset, ) from fairseq2.gang import Gang +from fairseq2.generation import BeamSearchConfig, create_seq2seq_generator from fairseq2.logging import get_log_writer from fairseq2.models import create_model from fairseq2.models.encoder_decoder import EncoderDecoderModel -from fairseq2.models.seq2seq import Seq2SeqBatch, as_auto_regressive_input -from fairseq2.models.sequence import SequenceModelOutput -from fairseq2.optim import AdamW -from fairseq2.optim.lr_scheduler import MyleLR +from fairseq2.models.seq2seq import Seq2SeqBatch +from fairseq2.optim import AdamWConfig, create_optimizer +from fairseq2.optim.lr_scheduler import MyleLRConfig, create_lr_scheduler from fairseq2.recipes.common_metrics import Seq2SeqMetricBag from fairseq2.recipes.evaluator import EvalUnit +from fairseq2.recipes.mt.common import MTCriterion from fairseq2.recipes.mt.eval import MTBleuChrfEvalUnit, MTLossEvalUnit -from fairseq2.recipes.mt.translate import ( - BeamSearchConfig, - SamplingConfig, - _create_sequence_generator, -) from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer from fairseq2.recipes.utils.asset import ( AssetReference, @@ -46,18 +43,15 @@ retrieve_asset_card, ) from fairseq2.recipes.utils.log import log_model, log_model_config -from fairseq2.recipes.utils.setup import ( - check_model_type, - setup_root_gang, - to_data_parallel, -) -from fairseq2.typing import META, DataType, override +from fairseq2.recipes.utils.setup import setup_root_gang, to_data_parallel +from fairseq2.typing import CPU, META, DataType from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class MTTrainConfig: """Holds the configuration of a machine translation training task. @@ -75,6 +69,9 @@ class MTTrainConfig: valid_split: str = "valid" """The name of the valid data split.""" + min_seq_len: int = 1 + """The minimum sequence length.""" + max_seq_len: int = 512 """The maximum sequence length.""" @@ -97,11 +94,11 @@ class MTTrainConfig: model_family: str = "transformer" """The family of the model.""" - model_arch: Optional[str] = "nllb_dense_600m" + model_arch: str | None = "nllb_dense_600m" """The architecture of the model.""" model_config: Any = None - """The model configuration.""" + """The configuration of the model.""" dtype: DataType = torch.float16 """The data type of the model.""" @@ -113,22 +110,26 @@ class MTTrainConfig: """The granularity at which to wrap the ASR model.""" # Optimizer, LR, and Loss - lr: float = 0.001 - """The initial (post-warm-up) learning rate.""" + optimizer: str = "adamw" + """The optimizer.""" - start_lr: float = 1e-7 - """The initial warm-up learning rate.""" + optimizer_config: Any = field( + default_factory=lambda: AdamWConfig(lr=0.001, betas=(0.9, 0.98)) + ) + """The configuration of the optimizer.""" - num_lr_warmup_steps: int = 8000 - """The number of learning rate warm-up steps.""" + lr_scheduler: str = "myle" + """The learning rate scheduler.""" - betas: Tuple[float, float] = (0.9, 0.98) - """The coefficients of AdamW.""" + lr_scheduler_config: Any = field( + default_factory=lambda: MyleLRConfig(start_lr=1e-7, num_warmup_steps=8000) + ) + """The configuration of the learning rate scheduler.""" - max_gradient_norm: Optional[float] = None + max_gradient_norm: float | None = None """The maximum gradient norm. If ``None``, no clipping will be applied.""" - fp16_loss_scale: Tuple[float, float] = (128.0, 0.0001) + fp16_loss_scale: tuple[float, float] = (128.0, 0.0001) """The initial and minimum loss scale for fp16 training.""" gradient_accumulation: int = 2 @@ -141,7 +142,7 @@ class MTTrainConfig: max_num_steps: int = 100_000 """The maximum number of steps to train for.""" - max_num_data_epochs: Optional[int] = None + max_num_data_epochs: int | None = None """The maximum number of data epochs to train for.""" validate_after_n_steps: int = 0 @@ -159,31 +160,30 @@ class MTTrainConfig: publish_metrics_every_n_steps: int = 200 """The step interval at which to publish metrics.""" - # Checkpointing - resume_checkpoint_dir: Optional[Path] = None + # Checkpoint + resume_checkpoint_dir: Path | None = None """If not ``None``, adds the specified path to the default asset store.""" # BLEU/chrF++ compute_bleu_chrf: bool = True """If ``True``, computes BLEU and chrF++ during validation.""" - generator_mode: Literal["beam_search", "sampling"] = "beam_search" - """The mode of sequence generation.""" + generator: str = "beam_search" + """The sequence generator.""" - beam_search: BeamSearchConfig = field(default_factory=lambda: BeamSearchConfig()) - """The configuration for beam search-based sequence generation.""" - - sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig()) - """The configuration for sampling-based sequence generation.""" + generator_config: Any = field( + default_factory=lambda: BeamSearchConfig(max_gen_len=(1, 256), echo_prompt=True) + ) + """The configuration of the sequence generator.""" generator_batch_size: int = 8 - """The number of sentences per generator batch.""" + """The number of sentences per batch.""" # Misc seed: int = 2 """The random number generator seed to use.""" - profile: Optional[Tuple[int, int]] = None + profile: tuple[int, int] | None = None """The number of steps that the PyTorch profiler should skip and then record.""" monitored_gang: bool = False @@ -202,9 +202,10 @@ class MTTrainConfig: def _nllb_dense_300m() -> MTTrainConfig: config = _nllb_dense_600m() + assert isinstance(config.lr_scheduler_config, MyleLRConfig) + config.model_arch = "nllb_dense_300m" - config.model_config = {"dropout_p": 0.3} - config.num_lr_warmup_steps = 400 + config.lr_scheduler_config.num_warmup_steps = 400 config.gradient_accumulation = 4 config.max_num_steps = 10_000 config.validate_every_n_steps = 1000 @@ -231,8 +232,6 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB CheckpointModelMetadataProvider(config.resume_checkpoint_dir) ) - seed = config.seed - tokenizer_card = retrieve_asset_card(config.tokenizer) # Load the tokenizer. @@ -259,7 +258,13 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB dataset = GenericParallelTextDataset.from_path(dataset_path) + seed = config.seed + # Initialize the model + manual_seed(seed, CPU, gang.device) + + seed += 1 + try: model, model_config = create_model( config.model_family, @@ -273,13 +278,15 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB "The model cannot be initialized. See nested exception for details." ) from ex - check_model_type(model, EncoderDecoderModel) + if not isinstance(model, EncoderDecoderModel): + raise ValueError( + f"The model must be of type `{EncoderDecoderModel}`, but is of type `{type(model)}` instead." + ) log_model_config(model_config, log) - checkpoint_manager.save_model_metadata( - family=model.family, config=model_config, tokenizer_name=tokenizer_card.name - ) + checkpoint_manager.save_model_metadata(family=model.family, config=model_config) + checkpoint_manager.save_tokenizer_metadata(tokenizer_card.name) has_checkpoint = checkpoint_manager.has_checkpoint() @@ -288,7 +295,6 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB gang, config.data_parallelism, log, - fsdp_skip_init=has_checkpoint, fsdp_broadcast_state=not has_checkpoint, fsdp_mixed_precision_dtype=config.dtype, fsdp_fp32_reduce=True, @@ -297,15 +303,13 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB log_model(dp_model, log, rank=gang.rank) - # Initialize the train unit and the optimizer. - unit = MTTrainUnit(dp_model, gang, label_smoothing=config.label_smoothing) + # Initialize the criterion. + criterion = MTCriterion(dp_model, label_smoothing=config.label_smoothing) - data_reader = dataset.create_reader( - config.split, - tokenizer, - gang, - config.max_seq_len, - batching=LengthBatching(config.max_num_tokens), + # Initialize the train unit. + unit = MTTrainUnit(criterion, gang) + + options = ParallelTextReadOptions( sample=True, example_shuffle_window=config.example_shuffle_window, batch_shuffle_window=config.batch_shuffle_window, @@ -314,52 +318,92 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB seed=seed, ) + try: + data_reader = dataset.create_reader( + config.split, + tokenizer, + gang, + config.min_seq_len, + config.max_seq_len, + LengthBatching(config.max_num_tokens), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + seed += 1 - optimizer = AdamW(dp_model.parameters(), lr=config.lr, betas=config.betas) + # Initialize the optimizer. + try: + optimizer = create_optimizer( + config.optimizer, dp_model, config.optimizer_config + ) + except ValueError as ex: + raise ValueError( + "The optimizer cannot be created. See nested exception for details." + ) from ex - lr_scheduler = MyleLR( - optimizer, - num_warmup_steps=config.num_lr_warmup_steps, - start_lr=config.start_lr, - ) + # Initialize the learning rate scheduler. + try: + lr_scheduler = create_lr_scheduler( + config.lr_scheduler, + optimizer, + config.lr_scheduler_config, + max_num_steps=config.max_num_steps, + ) + except ValueError as ex: + raise ValueError( + "The learning rate scheduler cannot be created. See nested exception for details." + ) from ex # Initialize the sequence generator. if config.compute_bleu_chrf: - generator = _create_sequence_generator( - model, config.generator_mode, config.beam_search, config.sampling # type: ignore[arg-type] - ) + try: + generator = create_seq2seq_generator( + config.generator, model, config.generator_config + ) + except ValueError as ex: + raise ValueError( + "The sequence generator cannot be created. See nested exception for details." + ) from ex else: generator = None # Initialize the validation units. - valid_units: List[EvalUnit[Seq2SeqBatch]] = [] + valid_units: list[EvalUnit[Seq2SeqBatch]] = [] valid_data_readers = [] for direction in dataset.directions(config.valid_split): # Loss Validation - valid_loss_unit = MTLossEvalUnit( - dp_model, - direction, - gang, - label_smoothing=config.label_smoothing, - ) + valid_loss_unit = MTLossEvalUnit(criterion, direction, gang) valid_units.append(valid_loss_unit) - valid_data_reader = dataset.create_reader( - config.valid_split, - tokenizer, - gang, - config.max_seq_len, - batching=LengthBatching(config.max_num_tokens), + options = ParallelTextReadOptions( direction=direction, - sync_batches=False, + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed, ) + try: + valid_data_reader = dataset.create_reader( + config.valid_split, + tokenizer, + gang, + config.min_seq_len, + config.max_seq_len, + LengthBatching(config.max_num_tokens), + options, + ) + except ValueError as ex: + raise ValueError( + f"The data reader for '{direction}' cannot be initialized. See nested exception for details." + ) from ex + seed += 1 valid_data_readers.append(valid_data_reader) @@ -372,22 +416,35 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB valid_units.append(valid_score_unit) - valid_data_reader = dataset.create_reader( - config.valid_split, - tokenizer, - gang, - config.max_seq_len, - batching=StaticBatching(config.generator_batch_size), + options = ParallelTextReadOptions( direction=direction, - sync_batches=False, + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed, ) + try: + valid_data_reader = dataset.create_reader( + config.valid_split, + tokenizer, + gang, + config.min_seq_len, + config.max_seq_len, + StaticBatching(config.generator_batch_size), + options, + ) + except ValueError as ex: + raise ValueError( + f"The data reader for '{direction}' cannot be initialized. See nested exception for details." + ) from ex + seed += 1 valid_data_readers.append(valid_data_reader) + # TODO: Fix once we support static mixed precision on one device. + amp = gang.size == 1 or config.data_parallelism != "fsdp" + # Initialize the trainer. return Trainer[Seq2SeqBatch]( unit=unit, @@ -398,6 +455,7 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB lr_scheduler=lr_scheduler, fp16_loss_scale=config.fp16_loss_scale, max_gradient_norm=config.max_gradient_norm, + amp=amp, max_num_steps=config.max_num_steps, max_num_data_epochs=config.max_num_data_epochs, score_metric_name="chrf" if config.compute_bleu_chrf else None, @@ -420,44 +478,19 @@ def load_mt_trainer(config: MTTrainConfig, output_dir: Path) -> Trainer[Seq2SeqB @final class MTTrainUnit(AbstractTrainUnit[Seq2SeqBatch]): - """Represents a machine translation training unit.""" - - _label_smoothing: float + _criterion: MTCriterion _metric_bag: Seq2SeqMetricBag - def __init__( - self, - model: Module, - gang: Gang, - *, - label_smoothing: float = 0.0, - ) -> None: - super().__init__(model) + def __init__(self, criterion: MTCriterion, gang: Gang) -> None: + super().__init__(criterion.model) - check_model_type(model, EncoderDecoderModel) - - self._label_smoothing = label_smoothing + self._criterion = criterion self._metric_bag = Seq2SeqMetricBag(gang) @override - def __call__(self, batch: Seq2SeqBatch) -> Tuple[Tensor, int]: - input_batch, target_batch = as_auto_regressive_input(batch) - - output = self._forward(input_batch) - - loss = output.compute_loss( - target_batch.seqs, label_smoothing=self._label_smoothing - ) - - self._metric_bag.update_nll_loss(input_batch, loss.detach()) - - self._metric_bag.update_batch_metrics(input_batch) - - return loss, batch.num_target_elements() - - def _forward(self, batch: Seq2SeqBatch) -> SequenceModelOutput: - return self._model(batch) # type: ignore[no-any-return] + def __call__(self, batch: Seq2SeqBatch) -> tuple[Tensor, int]: + return self._criterion(batch, self._metric_bag) @property @override diff --git a/src/fairseq2/recipes/mt/translate.py b/src/fairseq2/recipes/mt/translate.py index f0485c2cb..4965d2020 100644 --- a/src/fairseq2/recipes/mt/translate.py +++ b/src/fairseq2/recipes/mt/translate.py @@ -8,27 +8,28 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Literal, Optional, TextIO, Tuple, final +from typing import Any, TextIO, final import torch +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider from fairseq2.config_registry import ConfigRegistry from fairseq2.data.text import TextTokenizer, load_text_tokenizer from fairseq2.datasets import StaticBatching -from fairseq2.datasets.text import GenericTextDataset, load_text_dataset +from fairseq2.datasets.text import ( + GenericTextDataset, + TextReadOptions, + load_text_dataset, +) from fairseq2.gang import Gang from fairseq2.generation import ( - BeamSearchSeq2SeqGenerator, - Sampler, - SamplingSeq2SeqGenerator, + BeamSearchConfig, Seq2SeqGenerator, - StandardBeamSearchAlgorithm, - TopKSampler, - TopPSampler, + SequenceToTextConverter, + create_seq2seq_generator, ) -from fairseq2.generation.text import SequenceToTextConverter from fairseq2.logging import get_log_writer from fairseq2.models import load_model from fairseq2.models.encoder_decoder import EncoderDecoderModel @@ -41,18 +42,14 @@ retrieve_asset_card, ) from fairseq2.recipes.utils.log import log_model -from fairseq2.recipes.utils.setup import ( - broadcast_model, - check_model_type, - setup_root_gang, -) -from fairseq2.typing import META, DataType, override +from fairseq2.recipes.utils.setup import broadcast_model, setup_root_gang +from fairseq2.typing import META, DataType from fairseq2.utils.profiler import Stopwatch log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class TextTranslateConfig: """Holds the configuration of a text translation task.""" @@ -66,6 +63,9 @@ class TextTranslateConfig: target_lang: str = "deu_Latn" """The code of the language to translate to.""" + min_seq_len: int = 1 + """The minimum sequence length.""" + max_seq_len: int = 512 """The maximum sequence length.""" @@ -79,115 +79,29 @@ class TextTranslateConfig: model: AssetReference = "nllb-200_dense_distill_600m" """The name of the model to translate with.""" - checkpoint_dir: Optional[Path] = None + checkpoint_dir: Path | None = None """The checkpoint directory containing models saved by :class:`FileCheckpointManager`.""" dtype: DataType = torch.float16 """The data type of the model.""" - # Generation - generator_mode: Literal["beam_search", "sampling"] = "beam_search" - """The mode of sequence generation.""" + amp: bool = False + """If ``True``, runs evaluation with ``torch.amp``.""" - beam_search: BeamSearchConfig = field(default_factory=lambda: BeamSearchConfig()) - """The configuration for beam search-based sequence generation.""" + # Generation + generator: str = "beam_search" + """The sequence generator.""" - sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig()) - """The configuration for sampling-based sequence generation.""" + generator_config: Any = field( + default_factory=lambda: BeamSearchConfig(max_gen_len=(1, 256), echo_prompt=True) + ) + """The configuration of the sequence generator.""" # Misc seed: int = 2 """The random number generator seed to use.""" -@dataclass -class BeamSearchConfig: - """Holds the configuration for beam search-based sequence generation. - - See :class:`BeamSearchSeq2SeqGenerator` for more info. - """ - - algorithm: Literal["standard"] = "standard" - """The beam search algorithm.""" - - beam_size: int = 5 - """The beam size.""" - - min_gen_len: int = 1 - """The minimum generation length.""" - - max_gen_len: Tuple[int, int] = (1, 128) - """The maximum generation length.""" - - max_seq_len: Optional[int] = None - """The maximum sequence length including prompt.""" - - normalize_scores: bool = True - """If ``True``, normalizes scores by lengths of generated sequences.""" - - temperature: float = 1.0 - """The logit temperature.""" - - unk_penalty: float = 0.0 - """The UNK symbol penalty.""" - - len_penalty: float = 1.0 - """The length penalty.""" - - prefill_chunk_size: Optional[int] = 512 - """The prefill will be performed incrementally by chunks of this size.""" - - decode_capacity_increment: Optional[int] = 16 - """The sequence length capacity will be incremented by multiplies of this value.""" - - -@dataclass -class SamplingConfig: - """Holds the configuration for sampling-based sequence generation. - - See :class:`SamplingSeq2SeqGenerator` for more info. - """ - - sampler: Literal["top-p", "top-k"] = "top-p" - """The sampling algorithm.""" - - top_p: float = 1.0 - """The cumulative probability threshold for top-p sampling.""" - - top_k = 1 - """The number of top candidates to select from for top-k sampling.""" - - min_gen_len: int = 1 - """The minimum generation length.""" - - max_gen_len: Tuple[int, int] = (1, 128) - """The maximum generation length.""" - - max_seq_len: Optional[int] = None - """The maximum sequence length including prompt.""" - - compute_scores: bool = False - """If ``True``, computes scores of generated sequences.""" - - normalize_scores: bool = True - """If ``True``, normalizes scores by lengths of generated sequences.""" - - temperature: float = 1.0 - """The logit temperature.""" - - unk_penalty: float = 0.0 - """The UNK symbol penalty.""" - - len_penalty: float = 1.0 - """The length penalty.""" - - prefill_chunk_size: Optional[int] = 512 - """The prefill will be performed incrementally by chunks of this size.""" - - decode_capacity_increment: Optional[int] = 16 - """The sequence length capacity will be incremented by multiplies of this value.""" - - text_translate_presets = ConfigRegistry[TextTranslateConfig]() text_translate_preset = text_translate_presets.decorator @@ -198,6 +112,7 @@ def _nllb_dense_600m() -> TextTranslateConfig: return TextTranslateConfig() +@torch.inference_mode() def load_text_translator( config: TextTranslateConfig, output_dir: Path ) -> Generator[SequenceBatch]: @@ -211,8 +126,6 @@ def load_text_translator( gang = setup_root_gang(log) - seed = config.seed - model_card = retrieve_asset_card(config.model) # Load the tokenizer. @@ -254,7 +167,10 @@ def load_text_translator( "The model cannot be initialized. See nested exception for details." ) from ex - check_model_type(model, EncoderDecoderModel) + if not isinstance(model, EncoderDecoderModel): + raise ValueError( + f"The model must be of type `{EncoderDecoderModel}`, but is of type `{type(model)}` instead." + ) gang.barrier() @@ -267,9 +183,14 @@ def load_text_translator( log_model(model, log) # Initialize the sequence generator. - generator = _create_sequence_generator( - model, config.generator_mode, config.beam_search, config.sampling # type: ignore[arg-type] - ) + try: + generator = create_seq2seq_generator( + config.generator, model, config.generator_config + ) + except ValueError as ex: + raise ValueError( + "The sequence generator cannot be created. See nested exception for details." + ) from ex # Initialize the generator unit. src_output_file = output_dir.joinpath( @@ -309,17 +230,27 @@ def load_text_translator( task="translation", lang=config.source_lang, mode="source" ) - data_reader = dataset.create_reader( - text_encoder, - tokenizer.vocab_info.pad_idx, - gang, - config.max_seq_len, - batching=StaticBatching(config.batch_size), - sync_batches=False, - num_prefetch=config.num_prefetch, - seed=seed, + seed = config.seed + + options = TextReadOptions( + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed ) + try: + data_reader = dataset.create_reader( + text_encoder, + tokenizer.vocab_info.pad_idx, + gang, + config.min_seq_len, + config.max_seq_len, + StaticBatching(config.batch_size), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + seed += 1 # Initialize the generator. @@ -327,6 +258,8 @@ def load_text_translator( unit=unit, data_reader=data_reader, root_gang=gang, + dtype=config.dtype, + amp=config.amp, metrics_dir=output_dir.joinpath("metrics"), seed=seed, wall_watch=wall_watch, @@ -412,77 +345,3 @@ def __call__(self, batch: SequenceBatch) -> None: @override def metric_bag(self) -> Seq2SeqGenerationMetricBag: return self._metric_bag - - -def _create_sequence_generator( - model: EncoderDecoderModel, - mode: str, - beam_search_config: BeamSearchConfig, - sampling_config: SamplingConfig, -) -> Seq2SeqGenerator: - if mode == "beam_search": - return _create_beam_search_generator(model, beam_search_config) - - if mode == "sampling": - return _create_sampling_generator(model, sampling_config) - - raise ValueError( - f"`generator_mode` must be 'sampling' or 'beam_search', but is '{mode}' instead." - ) - - -def _create_beam_search_generator( - model: EncoderDecoderModel, config: BeamSearchConfig -) -> BeamSearchSeq2SeqGenerator: - if config.algorithm == "standard": - algorithm = StandardBeamSearchAlgorithm() - else: - raise ValueError( - f"`beam_search.algorithm` must be 'standard', but is '{config.algorithm}' instead." - ) - - return BeamSearchSeq2SeqGenerator( - model, - algorithm=algorithm, - beam_size=config.beam_size, - min_gen_len=config.min_gen_len, - max_gen_len=config.max_gen_len, - echo_prompt=True, - normalize_scores=config.normalize_scores, - temperature=config.temperature, - unk_penalty=config.unk_penalty, - len_penalty=config.len_penalty, - prefill_chunk_size=config.prefill_chunk_size, - decode_capacity_increment=config.decode_capacity_increment, - ) - - -def _create_sampling_generator( - model: EncoderDecoderModel, config: SamplingConfig -) -> SamplingSeq2SeqGenerator: - sampler: Sampler - - if config.sampler == "top-p": - sampler = TopPSampler(config.top_p) - elif config.sampler == "top-k": - sampler = TopKSampler(config.top_k) - else: - raise ValueError( - f"`sampling.sampler` must be 'top-p' or 'top-k', but is '{config.sampler}' instead." - ) - - return SamplingSeq2SeqGenerator( - model, - sampler, - min_gen_len=config.min_gen_len, - max_gen_len=config.max_gen_len, - max_seq_len=config.max_seq_len, - echo_prompt=True, - compute_scores=config.compute_scores, - normalize_scores=config.normalize_scores, - temperature=config.temperature, - unk_penalty=config.unk_penalty, - len_penalty=config.len_penalty, - prefill_chunk_size=config.prefill_chunk_size, - decode_capacity_increment=config.decode_capacity_increment, - ) diff --git a/src/fairseq2/recipes/runner.py b/src/fairseq2/recipes/runner.py new file mode 100644 index 000000000..a92c03460 --- /dev/null +++ b/src/fairseq2/recipes/runner.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Hashable, Set +from functools import cache +from itertools import chain +from pathlib import Path +from signal import SIGUSR1, signal +from types import FrameType +from typing import ( + Mapping, + Protocol, + Sequence, + TypeAlias, + TypeVar, + final, + runtime_checkable, +) + +from typing_extensions import override + +from fairseq2.config_registry import ConfigProvider +from fairseq2.error import ContractError, SetupError +from fairseq2.logging import log +from fairseq2.recipes.cluster import ClusterResolver +from fairseq2.recipes.logging import LoggingInitializer +from fairseq2.recipes.utils.log import log_config +from fairseq2.recipes.utils.sweep_tagger import SweepTagger +from fairseq2.utils.file import FileSystem +from fairseq2.utils.structured import ( + StructureError, + merge_unstructured, + structure, + unstructure, +) +from fairseq2.utils.yaml import YamlDumper, YamlError, YamlLoader + + +class RecipeRunner(ABC): + @abstractmethod + def run(self, config: object, output_dir: Path) -> None: + ... + + +Recipe: TypeAlias = Callable[[], None] + + +ConfigT_contra = TypeVar("ConfigT_contra", contravariant=True) + + +class RecipeLoader(Protocol[ConfigT_contra]): + def __call__(self, config: ConfigT_contra, output_dir: Path) -> Recipe: + ... + + +ConfigT = TypeVar("ConfigT") + + +@final +class StandardRecipeRunner(RecipeRunner): + _loader: RecipeLoader[object] + _signal_handler: SignalHandler + + def __init__( + self, loader: RecipeLoader[object], signal_handler: SignalHandler + ) -> None: + self._loader = loader + self._signal_handler = signal_handler + + @override + def run(self, config: object, output_dir: Path) -> None: + recipe = self._loader(config, output_dir) + + # If the recipe is stoppable, use SIGUSR1 as the stop signal. + if isinstance(recipe, Stoppable): + + def request_stop(nr: int) -> None: + log.info("SIGUSR1 received. Requesting recipe to stop.") + + recipe.request_stop() + + self._signal_handler.set(SIGUSR1, request_stop) + + recipe() + + +@runtime_checkable +class Stoppable(Protocol): + """Represents a task that supports graceful stopping.""" + + def request_stop(self) -> None: + ... + + +class SignalHandler(ABC): + @abstractmethod + def set(self, nr: int, callback: Callable[[int], None]) -> None: + ... + + +@final +class SystemSignalHandler(SignalHandler): + @override + def set(self, nr: int, callback: Callable[[int], None]) -> None: + def cb(signum: int, frame: FrameType | None) -> None: + callback(signum) + + signal(nr, cb) + + +class EnvironmentBootstrapper(ABC): + @abstractmethod + def run( + self, + preset: str, + config: object, + output_dir: Path, + *, + cluster: str = "auto", + sweep_format: str | None = None, + debug: bool = False, + ) -> Path: + ... + + +@final +class StandardEnvironmentBootstrapper(EnvironmentBootstrapper): + _cluster_resolver: ClusterResolver + _sweep_tagger: SweepTagger + _file_system: FileSystem + _logging_initializer: LoggingInitializer + _yaml_dumper: YamlDumper + + def __init__( + self, + cluster_resolver: ClusterResolver, + sweep_tagger: SweepTagger, + file_system: FileSystem, + logging_initializer: LoggingInitializer, + yaml_dumper: YamlDumper, + ) -> None: + self._cluster_resolver = cluster_resolver + self._sweep_tagger = sweep_tagger + self._file_system = file_system + self._logging_initializer = logging_initializer + self._yaml_dumper = yaml_dumper + + @override + def run( + self, + preset: str, + config: object, + output_dir: Path, + *, + cluster: str = "auto", + sweep_format: str | None = None, + debug: bool = False, + ) -> Path: + cluster_handler = self._cluster_resolver.get(cluster) + + world_size, rank = cluster_handler.set_torch_distributed_variables() + + unstructured_config = unstructure(config) + + sweep_tag = self._sweep_tagger.generate( + world_size, preset, unstructured_config, sweep_format + ) + + sweep_output_dir = output_dir.joinpath(sweep_tag) + + try: + self._file_system.make_directory(sweep_output_dir) + except OSError as ex: + raise SetupError( + f"The '{sweep_output_dir}' recipe output directory cannot be created. See the nested exception for details." + ) from ex + + self._logging_initializer.initialize( + sweep_output_dir.joinpath("logs/rank_{rank}.log"), debug=debug + ) + + log.info("The log files stored under the '{}' directory.", sweep_output_dir) + + log_config(unstructured_config, log) + + if rank == 0: + config_file = sweep_output_dir.joinpath("config.yaml") + + try: + self._yaml_dumper(unstructured_config, config_file) + except (OSError, YamlError) as ex: + raise SetupError( + f"The recipe configuration cannot be saved to the '{config_file}' file. See the nested exception for details." + ) from ex + + return sweep_output_dir + + +class ConfigReader(ABC): + @abstractmethod + def read( + self, + preset: str, + config_files: Sequence[Sequence[Path]] | None, + config_overrides: Sequence[Mapping[str, object]] | None, + ) -> object: + ... + + +@final +class StandardConfigReader(ConfigReader): + _preset_configs: ConfigProvider[object] + _file_system: FileSystem + _yaml_loader: YamlLoader + + def __init__( + self, + preset_configs: ConfigProvider[object], + file_system: FileSystem, + yaml_loader: YamlLoader, + ) -> None: + self._preset_configs = preset_configs + self._file_system = file_system + self._yaml_loader = yaml_loader + + @override + def read( + self, + preset: str, + config_files: Sequence[Sequence[Path]] | None, + config_overrides: Sequence[Mapping[str, object]] | None, + ) -> object: + # Load the preset configuration. + preset_config = self._preset_configs.get(preset) + + try: + unstructured_config = unstructure(preset_config) + except StructureError as ex: + raise ContractError( + f"The '{preset}' preset configuration cannot be unstructured. See the nested exception for details." + ) from ex + + # Update the configuration with `--config-file`. + if config_files: + for config_file in chain.from_iterable(config_files): + if not self._file_system.is_file(config_file): + raise ConfigFileNotFoundError(config_file) + + try: + unstructured_config_overrides = self._yaml_loader(config_file) + except YamlError as ex: + raise StructureError( + f"The '{config_file}' configuration file cannot be merged with the preset configuration. See the nested exception for details." + ) from ex + except OSError as ex: + raise SetupError( + f"The '{config_file}' configuration file cannot be read. See the nested exception for details." + ) from ex + + try: + unstructured_config = merge_unstructured( + unstructured_config, unstructured_config_overrides[0] + ) + except StructureError as ex: + raise StructureError( + f"The '{config_file}' configuration file cannot be merged with the preset configuration. See the nested exception for details." + ) from ex + + # Update the configuration with `--config`. + if config_overrides: + for overrides in config_overrides: + try: + unstructured_config = merge_unstructured( + unstructured_config, overrides + ) + except StructureError as ex: + raise StructureError( + "The command line configuration overrides cannot be merged with the preset recipe configuration. See the nested exception for details." + ) from ex + + return structure(unstructured_config, self._preset_configs.config_kls) # type: ignore[no-any-return] + + +class ConfigFileNotFoundError(Exception): + config_file: Path + + def __init__(self, config_file: Path) -> None: + super().__init__( + f"The '{config_file}' path does not point to a configuration file." + ) + + self.config_file = config_file + + +def get_sweep_keys(extra_sweep_keys: Set[Hashable] | None) -> Set[Hashable]: + sweep_keys = get_default_sweep_keys() + + if extra_sweep_keys is not None: + sweep_keys = sweep_keys | extra_sweep_keys + + return sweep_keys + + +@cache +def get_default_sweep_keys() -> Set[Hashable]: + return { + "batch_shuffle_window", + "betas", + "data_parallelism", + "dataset", + "dtype", + "example_shuffle_window", + "final_lr_ratio", + "final_lr_scale", + "fp16_loss_scale", + "fsdp_reshard_after_forward", + "fsdp_wrap_granularity", + "gradient_accumulation", + "label_smoothing", + "lr", + "lr_stage_ratios", + "max_gradient_norm", + "max_num_elements", + "max_num_steps", + "max_num_tokens", + "max_seq_len", + "mixed_precision", + "model", + "model_arch", + "model_config", + "num_lr_warmup_steps", + "pretrained_model", + "seed", + "split", + "start_lr", + "start_lr_scale", + "tensor_parallel_size", + "tokenizer", + "train_split", + "valid_split", + "weight_decay", + } diff --git a/src/fairseq2/recipes/trainer.py b/src/fairseq2/recipes/trainer.py index d6c62cb13..85adc9d3e 100644 --- a/src/fairseq2/recipes/trainer.py +++ b/src/fairseq2/recipes/trainer.py @@ -7,26 +7,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from contextlib import nullcontext +from collections.abc import Iterable, Sequence +from contextlib import AbstractContextManager, nullcontext from itertools import count from pathlib import Path from statistics import mean -from typing import ( - Any, - ContextManager, - Dict, - Generic, - List, - Optional, - Sequence, - Tuple, - TypeVar, - final, -) +from typing import Generic, TypeVar, final import torch import torch.distributed -from rich.progress import Progress +from rich.progress import Progress, TaskID from torch import Tensor from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn import Module @@ -34,18 +24,20 @@ from torch.optim import Optimizer from torch.profiler import record_function from torcheval.metrics import Mean +from typing_extensions import override from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError from fairseq2.datasets import DataReader -from fairseq2.early_stopper import EarlyStopper -from fairseq2.gang import FakeGang, Gang, all_sum, broadcast_flag -from fairseq2.logging import get_log_writer +from fairseq2.error import ContractError, InternalError, InvalidOperationError +from fairseq2.gang import FakeGang, Gang, broadcast_flag +from fairseq2.logging import log from fairseq2.metrics import ( JsonFileMetricRecorder, LogMetricRecorder, MetricBag, MetricRecorder, TensorBoardRecorder, + WandbRecorder, format_metric_value, record_metrics, ) @@ -57,17 +49,15 @@ ) from fairseq2.optim import DynamicLossScaler from fairseq2.optim.lr_scheduler import LRScheduler, NoopLR, get_effective_lr -from fairseq2.recipes.common_metrics import set_throughput_value +from fairseq2.recipes.common_metrics import extend_batch_metrics +from fairseq2.recipes.early_stopper import EarlyStopper, NoopEarlyStopper from fairseq2.recipes.evaluator import EvalUnit -from fairseq2.recipes.utils.cli import create_rich_progress -from fairseq2.typing import CPU, DataType, override +from fairseq2.recipes.utils.rich import create_rich_progress +from fairseq2.typing import CPU, DataType from fairseq2.utils.profiler import Profiler, Stopwatch from fairseq2.utils.rng import RngBag from fairseq2.utils.state import FSDPOptimizerStateHandler, StatefulObjectBag -log = get_log_writer(__name__) - - BatchT = TypeVar("BatchT") BatchT_contra = TypeVar("BatchT_contra", contravariant=True) @@ -77,11 +67,13 @@ class TrainUnit(ABC, Generic[BatchT_contra]): """Represents a unit to be used with :class:`Trainer`.""" @abstractmethod - def __call__(self, batch: BatchT_contra) -> Tuple[Tensor, int]: + def __call__(self, batch: BatchT_contra) -> tuple[Tensor, int | None]: """Process ``batch``. :returns: - The loss and the number of targets used to compute the loss. + - The loss. + - The number of targets used to compute the loss. If ``None``, the + model gradients won't be normalized. """ @abstractmethod @@ -130,41 +122,53 @@ class Trainer(StatefulObjectBag, Generic[BatchT]): _optimizer: Optimizer _lr_scheduler: LRScheduler _loss_scaler: DynamicLossScaler - _max_gradient_norm: Optional[float] + _max_gradient_norm: float | None + _amp: bool _step_nr: int - _max_num_steps: Optional[int] + _max_num_steps: int | None _data_epoch_nr: int - _max_num_data_epochs: Optional[int] - _eod: bool + _max_num_data_epochs: int | None + _repeat_step: bool + _read_data: bool + _num_effective_batches: int + _end_of_data_epoch: bool + _end_of_data: bool _should_stop: bool - _score_metric_name: Optional[str] + _score_metric_name: str | None _lower_better: bool - _early_stopper: Optional[EarlyStopper] - _best_step_and_score: Optional[Tuple[int, float]] - _valid_score: Optional[float] + _early_stopper: EarlyStopper | None + _best_step_and_score: tuple[int, float] | None + _valid_score: float | None _valid_units: Sequence[EvalUnit[BatchT]] _valid_data_readers: Sequence[DataReader[BatchT]] _validate_after_n_steps: int - _validate_every_n_steps: int + _validate_every_n_steps: int | None + _validate_after_n_data_epochs: int + _validate_every_n_data_epochs: int | None _checkpoint_manager: CheckpointManager _checkpoint_after_n_steps: int - _checkpoint_every_n_steps: Optional[int] - _keep_last_n_checkpoints: Optional[int] - _keep_best_n_checkpoints: Optional[int] - _keep_last_n_models: Optional[int] - _keep_best_n_models: Optional[int] + _checkpoint_every_n_steps: int | None + _checkpoint_after_n_data_epochs: int + _checkpoint_every_n_data_epochs: int | None + _keep_last_n_checkpoints: int | None + _keep_best_n_checkpoints: int | None + _keep_last_n_models: int | None + _keep_best_n_models: int | None _metric_bag: MetricBag - _metric_recorders: List[MetricRecorder] + _metric_recorders: list[MetricRecorder] _publish_metrics_after_n_steps: int - _publish_metrics_every_n_steps: int + _publish_metrics_every_n_steps: int | None + _publish_metrics_after_n_data_epochs: int + _publish_metrics_every_n_data_epochs: int | None _profiler: Profiler _anomaly_detection: bool _seed: int _rng_bag: RngBag _wall_watch: Stopwatch - _step_time: float + _total_step_time: float _run: bool _progress: Progress + _train_task_id: TaskID def __init__( self, @@ -175,32 +179,41 @@ def __init__( optimizer: Optimizer, checkpoint_manager: CheckpointManager, wall_watch: Stopwatch, + dp_gang: Gang | None = None, + tp_gang: Gang | None = None, dtype: DataType = torch.float32, - dp_gang: Optional[Gang] = None, - tp_gang: Optional[Gang] = None, - lr_scheduler: Optional[LRScheduler] = None, - fp16_loss_scale: Tuple[float, float] = (128.0, 0.0001), - max_gradient_norm: Optional[float] = None, - max_num_steps: Optional[int] = 1000, - max_num_data_epochs: Optional[int] = None, - score_metric_name: Optional[str] = None, + lr_scheduler: LRScheduler | None = None, + fp16_loss_scale: tuple[float, float] = (128.0, 0.0001), + max_gradient_norm: float | None = None, + amp: bool = False, + max_num_steps: int | None = None, + max_num_data_epochs: int | None = None, + score_metric_name: str | None = None, lower_better: bool = False, - early_stopper: Optional[EarlyStopper] = None, - valid_units: Optional[Sequence[EvalUnit[BatchT]]] = None, - valid_data_readers: Optional[Sequence[DataReader[BatchT]]] = None, + early_stopper: EarlyStopper | None = None, + valid_units: Sequence[EvalUnit[BatchT]] | None = None, + valid_data_readers: Sequence[DataReader[BatchT]] | None = None, validate_after_n_steps: int = 0, - validate_every_n_steps: int = 100, + validate_every_n_steps: int | None = None, + validate_after_n_data_epochs: int = 0, + validate_every_n_data_epochs: int | None = None, checkpoint_after_n_steps: int = 0, - checkpoint_every_n_steps: Optional[int] = None, - keep_last_n_checkpoints: Optional[int] = None, - keep_best_n_checkpoints: Optional[int] = None, - keep_last_n_models: Optional[int] = None, - keep_best_n_models: Optional[int] = None, - tb_dir: Optional[Path] = None, - metrics_dir: Optional[Path] = None, + checkpoint_every_n_steps: int | None = None, + checkpoint_after_n_data_epochs: int = 0, + checkpoint_every_n_data_epochs: int | None = None, + keep_last_n_checkpoints: int | None = None, + keep_best_n_checkpoints: int | None = None, + keep_last_n_models: int | None = None, + keep_best_n_models: int | None = None, + metric_recorders: Iterable[MetricRecorder] | None = None, + tb_dir: Path | None = None, + metrics_dir: Path | None = None, + wandb_options: tuple[Path, str, str] | None = None, publish_metrics_after_n_steps: int = 0, - publish_metrics_every_n_steps: int = 100, - profile: Optional[Tuple[int, int]] = None, + publish_metrics_every_n_steps: int | None = None, + publish_metrics_after_n_data_epochs: int = 0, + publish_metrics_every_n_data_epochs: int | None = None, + profile: tuple[int, int] | None = None, anomaly_detection: bool = False, seed: int = 2, ) -> None: @@ -218,13 +231,15 @@ def __init__( :param wall_watch: The stopwatch to track process wall-time. :param dtype: - The data type to train with. + The data type of the model. :param dp_gang: The data parallel gang. If ``None``, ``gang`` will be used. :param tp_gang: The tensor parallel gang. Only required for tensor parallel models. :param lr_scheduler: The learning rate scheduler. + :param amp: + If ``True``, enables ``torch.amp``. :param fp16_loss_scale: The initial and minimum loss scale for fp16 training. :param max_gradient_norm: @@ -247,10 +262,18 @@ def __init__( The number of steps after which to start validating the model. :param validate_every_n_steps: The step interval at which to validate the model. + :param validate_after_n_data_epochs: + The number of data epochs after which to start validating the model. + :param validate_every_n_data_epochs: + The data epoch interval at which to validate the model. :param checkpoint_after_n_steps: The number of steps after which to start checkpointing. :param checkpoint_every_n_steps: The step interval at which to checkpoint. + :param checkpoint_after_n_data_epochs: + The number of data epochs after which to start checkpointing. + :param checkpoint_every_n_data_epochs: + The data epoch interval at which to checkpoint. :param keep_last_n_checkpoints: The number of checkpoints to keep. If ``None``, none will be deleted. :param keep_best_n_checkpoints: @@ -263,14 +286,22 @@ def __init__( The number of best checkpoint models to keep based on their validation score. Must be greater than or equal to ``keep_best_n_checkpoints``. + :param metric_recorders: + The metric recorders. :param tb_dir: - The TensorBoard log directory to dump metrics. + Legacy. Use ``metric_recorders``. :param metrics_dir: - The directory to dump metrics. + Legacy. Use ``metric_recorders``. + :param wandb_options: + Legacy. Use ``metric_recorders``. :param publish_metrics_after_n_steps: The number of steps after which to start publishing metrics. :param publish_metrics_every_n_steps: The step interval at which to publish metrics. + :param publish_metrics_after_n_data_epochs: + The number of data epochs after which to start publishing metrics. + :param publish_metrics_every_n_data_epochs: + The data epoch interval at which to publish metrics. :param profile: The number of steps that the PyTorch profiler should skip and then record. @@ -308,7 +339,8 @@ def __init__( self._dtype = dtype - if uses_fsdp := isinstance(self._model, FSDP): + uses_fsdp = isinstance(self._model, FSDP) + if uses_fsdp: self.register_stateful( "_optimizer", optimizer, FSDPOptimizerStateHandler(self._model) ) @@ -322,7 +354,7 @@ def __init__( self._loss_scaler = DynamicLossScaler( optimizer, root_gang, - sharded=uses_fsdp or self._tp_gang.size > 0, + sharded=uses_fsdp or self._tp_gang.size > 1, init_scale=fp16_init_scale, min_scale=fp16_min_scale, gradient_accumulation=self._data_reader.num_accumulate, @@ -331,17 +363,32 @@ def __init__( self._max_gradient_norm = max_gradient_norm + self._amp = amp + self.register_stateful("_step_nr", 0) + if max_num_steps is not None: + if max_num_steps <= 0: + raise ValueError("`max_num_steps` must be greater than zero.") + self._max_num_steps = max_num_steps self.register_stateful("_data_epoch_nr", 1) + if max_num_data_epochs is not None: + if max_num_data_epochs <= 0: + raise ValueError("`max_num_data_epochs` must be greater than zero.") + self._max_num_data_epochs = max_num_data_epochs - self._read_data = False + self._repeat_step = False + + self._read_data = False # Indicates whether we have read any data. - self._eod = max_num_data_epochs == 0 + self._num_effective_batches = 0 + + self._end_of_data_epoch = False + self._end_of_data = False self._should_stop = False @@ -356,7 +403,7 @@ def __init__( ) if root_gang.rank != 0: - early_stopper = lambda step_nr, score: False + early_stopper = NoopEarlyStopper() self._early_stopper = early_stopper else: @@ -384,38 +431,65 @@ def __init__( "`valid_units` and `valid_data_readers` must be both specified." ) - if validate_every_n_steps == 0: - raise ValueError("`validate_every_n_steps` must be greater than zero.") + if validate_every_n_steps is not None: + if validate_every_n_steps <= 0: + raise ValueError("`validate_every_n_steps` must be greater than zero.") self._validate_after_n_steps = validate_after_n_steps self._validate_every_n_steps = validate_every_n_steps + if validate_every_n_data_epochs is not None: + if validate_every_n_data_epochs <= 0: + raise ValueError( + "`validate_every_n_data_epochs` must be greater than zero." + ) + + self._validate_after_n_data_epochs = validate_after_n_data_epochs + self._validate_every_n_data_epochs = validate_every_n_data_epochs + self._checkpoint_manager = checkpoint_manager - if checkpoint_every_n_steps == 0: - raise ValueError("`checkpoint_every_n_steps` must be greater than zero.") + if checkpoint_every_n_steps is not None: + if checkpoint_every_n_steps <= 0: + raise ValueError( + "`checkpoint_every_n_steps` must be greater than zero." + ) self._checkpoint_after_n_steps = checkpoint_after_n_steps self._checkpoint_every_n_steps = checkpoint_every_n_steps - if keep_last_n_checkpoints is not None and keep_best_n_checkpoints is not None: - raise ValueError( - "`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time." - ) + if checkpoint_every_n_data_epochs is not None: + if checkpoint_every_n_data_epochs <= 0: + raise ValueError( + "`checkpoint_every_n_data_epochs` must be greater than zero." + ) + + self._checkpoint_after_n_data_epochs = checkpoint_after_n_data_epochs + self._checkpoint_every_n_data_epochs = checkpoint_every_n_data_epochs - if keep_last_n_checkpoints == 0: - raise ValueError("`keep_last_n_checkpoints` must be greater than zero.") + if keep_last_n_checkpoints is not None: + if keep_best_n_checkpoints is not None: + raise ValueError( + "`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time." + ) - if keep_best_n_checkpoints == 0: - raise ValueError("`keep_best_n_checkpoints` must be greater than zero.") + if keep_last_n_checkpoints <= 0: + raise ValueError("`keep_last_n_checkpoints` must be greater than zero.") + elif keep_best_n_checkpoints is not None: + if keep_best_n_checkpoints <= 0: + raise ValueError("`keep_best_n_checkpoints` must be greater than zero.") - if keep_best_n_checkpoints is not None: if checkpoint_every_n_steps is not None: if score_metric_name is None: raise ValueError( "`score_metric_name` must be specified when `keep_best_n_checkpoints` is specified." ) + if validate_every_n_steps is None: + raise ValueError( + "`validate_every_n_steps` must be specified when `keep_best_n_checkpoints` is specified." + ) + if checkpoint_every_n_steps % validate_every_n_steps != 0: raise ValueError( f"`checkpoint_every_n_steps` must be a multiple of `validate_every_n_steps` ({validate_every_n_steps}) when `keep_best_n_checkpoints` is specified, but is {checkpoint_every_n_steps} instead." @@ -455,16 +529,27 @@ def __init__( self._metric_bag = unit.metric_bag - if root_gang.rank == 0: - self._metric_recorders = [LogMetricRecorder(log)] + if metric_recorders is None: + # compat + if root_gang.rank == 0: + self._metric_recorders = [LogMetricRecorder(log)] - if tb_dir is not None: - self._metric_recorders.append(TensorBoardRecorder(tb_dir)) + if tb_dir is not None: + self._metric_recorders.append(TensorBoardRecorder(tb_dir)) - if metrics_dir is not None: - self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir)) + if metrics_dir is not None: + self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir)) + + if wandb_options is not None: + wandb_dir, wandb_project, wandb_name = wandb_options + + self._metric_recorders.append( + WandbRecorder(wandb_project, wandb_name, wandb_dir) + ) + else: + self._metric_recorders = [] else: - self._metric_recorders = [] + self._metric_recorders = list(metric_recorders) if publish_metrics_every_n_steps == 0: raise ValueError( @@ -474,20 +559,32 @@ def __init__( self._publish_metrics_after_n_steps = publish_metrics_after_n_steps self._publish_metrics_every_n_steps = publish_metrics_every_n_steps + if publish_metrics_every_n_data_epochs == 0: + raise ValueError( + "`publish_metrics_every_n_data_epochs` must be greater than zero." + ) + + self._publish_metrics_after_n_data_epochs = publish_metrics_after_n_data_epochs + self._publish_metrics_every_n_data_epochs = publish_metrics_every_n_data_epochs + if profile is None or tb_dir is None: if profile is not None and tb_dir is None: log.warning("No TensorBoard log directory provided. Profiling will be disabled.") # fmt: skip - skip_first, active_steps = 1, 0 + num_skip_steps, num_record_steps = 1, 0 profile_dir = Path() + + enabled = False else: - skip_first, active_steps = profile + num_skip_steps, num_record_steps = profile profile_dir = tb_dir + enabled = num_record_steps > 0 + self._profiler = Profiler( - skip_first, active_steps, profile_dir, root_gang, enabled=active_steps > 0 + num_skip_steps, num_record_steps, profile_dir, root_gang, enabled=enabled ) self._anomaly_detection = anomaly_detection @@ -498,7 +595,7 @@ def __init__( self._wall_watch = wall_watch - self._step_time = 0.0 + self._total_step_time = 0.0 self._run = False @@ -510,10 +607,9 @@ def request_stop(self) -> None: self._should_stop = True - @override def __call__(self) -> None: if self._run: - raise RuntimeError("The trainer can only be run once.") + raise InvalidOperationError("The trainer can only be run once.") self._run = True @@ -567,14 +663,16 @@ def _maybe_restore_state(self) -> None: def _do_run(self) -> None: with self._progress, self._profiler: - train_task = self._progress.add_task( + self._train_task_id = self._progress.add_task( "train", total=self._max_num_steps, completed=self._step_nr ) while self._should_run_step(): + self._maybe_advance_data_epoch() + self._step_nr += 1 - self._progress.update(train_task, advance=1) + self._progress.update(self._train_task_id, advance=1) detect_anomaly = torch.autograd.set_detect_anomaly( # type: ignore[attr-defined] self._anomaly_detection, check_nan=True @@ -582,10 +680,7 @@ def _do_run(self) -> None: with detect_anomaly: with record_function(f"step_{self._step_nr}"): - try: - self._run_step() - except StopIteration: - self._eod = True + self._run_step() if self._should_publish_metrics(): self._publish_metrics() @@ -603,7 +698,7 @@ def _do_run(self) -> None: self._valid_score = None def _should_run_step(self) -> bool: - if self._eod or self._should_stop: + if self._end_of_data or self._should_stop: return False if self._max_num_steps is None: @@ -611,136 +706,168 @@ def _should_run_step(self) -> bool: return self._step_nr < self._max_num_steps + def _maybe_advance_data_epoch(self) -> None: + if not self._end_of_data_epoch: + return + + self._data_epoch_nr += 1 + + self._end_of_data_epoch = False + def _run_step(self) -> None: step_nr = self._step_nr - log.debug("Running training step {}.", step_nr) - - stepped = False + log.debug("{} training step {}.", "Repeating" if self._repeat_step else "Running", step_nr) # fmt: skip watch = Stopwatch(start=True, device=self._root_gang.device) - with record_function(f"step_{step_nr}_prologue"): - self._unit.set_step_nr(step_nr) + # Collect the batches. + with record_function(f"step_{step_nr}_data_load"): + batches = self._next_batches() + if batches is None: + return - while not stepped: - # Collect the batches. - with record_function(f"step_{step_nr}_data_load"): - batches = self._next_batches() + # Prepare the unit. + if not self._repeat_step: + with record_function(f"step_{step_nr}_unit_setup"): + self._unit.set_step_nr(step_nr) - num_targets = 0 + num_targets = 0 - if self._loss_scaler.is_enabled: - self._metric_bag.begin_updates() + if self._loss_scaler.is_enabled: + self._metric_bag.begin_updates() + + # Accumulate. + for batch_nr, batch in enumerate(batches): + with self._maybe_no_sync(batch_nr, len(batches)): + with record_function(f"step_{step_nr}_{batch_nr}_forward"): + batch_loss, num_batch_targets = self._compute_loss(batch) - # Accumulate. - for batch_nr, batch in enumerate(batches): - with self._maybe_no_sync(batch_nr, len(batches)): - with record_function(f"step_{step_nr}_{batch_nr}_forward"): - batch_loss, num_batch_targets = self._compute_loss(batch) + if num_batch_targets is not None: + if num_batch_targets == 0: + raise ContractError( + "The train unit returned zero loss targets." + ) - with record_function(f"step_{step_nr}_{batch_nr}_backward"): - self._loss_scaler.backward(batch_loss) + num_targets += num_batch_targets - num_targets += num_batch_targets + with record_function(f"step_{step_nr}_{batch_nr}_backward"): + self._loss_scaler.backward(batch_loss) - # Normalize. + # Normalize. + if num_targets > 0: normalize_gradients(self._model, self._dp_gang, num_targets=num_targets) - # Clip. - with record_function(f"step_{step_nr}_grad_norm"): - self._loss_scaler.unscale_gradients_() + # Clip. + with record_function(f"step_{step_nr}_grad_norm"): + self._loss_scaler.unscale_gradients_() - # TODO(balioglu): Support tensor parallelism! - grad_norm = clip_gradient_norm( - self._model, max_norm=self._max_gradient_norm + # TODO(balioglu): Support tensor parallelism! + grad_norm = clip_gradient_norm( + self._model, max_norm=self._max_gradient_norm + ) + + # Sanity check. + if not check_gradient_norms(grad_norm, self._dp_gang, step_nr): + raise FloatingPointError( + f"The gradients are inconsistent between processes at step {step_nr}. Training cannot continue." ) - # Sanity check. - if not check_gradient_norms(grad_norm, self._dp_gang, step_nr): - raise FloatingPointError( - f"The gradients are inconsistent between processes at step {step_nr}. Training cannot continue." - ) + # Update the parameters. + with record_function(f"step_{step_nr}_optimizer"): + _, scale_result = self._loss_scaler.run_optimizer_step(step_nr) - # Update the parameters. - with record_function(f"step_{step_nr}_optimizer"): - _, scale_result = self._loss_scaler.run_optimizer_step(step_nr) + if scale_result.overflow: + self._metric_bag.rollback_updates() - if scale_result.overflow: - self._metric_bag.rollback_updates() + if scale_result.min_reached: + raise FloatingPointError( + f"The gradients are scaled down to minimum at step {step_nr}. Training cannot continue." + ) - if scale_result.min_reached: - raise FloatingPointError( - f"The gradients are scaled down to minimum at step {step_nr}. Training cannot continue." - ) + # Repeat the step with the next batch. + self._step_nr -= 1 - log.debug("Repeating step {}.", step_nr) - else: - self._lr_scheduler.step() + self._progress.update(self._train_task_id, advance=-1) - if self._loss_scaler.is_enabled: - self._metric_bag.commit_updates() + self._repeat_step = True + else: + self._lr_scheduler.step() - self._metric_bag.gradient_norm.update(grad_norm) + if self._loss_scaler.is_enabled: + self._metric_bag.commit_updates() - stepped = True + self._metric_bag.gradient_norm.update(grad_norm) - # Reset. - self._optimizer.zero_grad(set_to_none=True) + self._repeat_step = False - self._step_time += watch.get_elapsed_time() + self._num_effective_batches += 1 - def _next_batches(self) -> List[BatchT]: - while True: - try: - batches = next(self._data_reader) + # Reset. + self._optimizer.zero_grad(set_to_none=True) - self._read_data = True + self._total_step_time += watch.get_elapsed_time() - return batches - except StopIteration: - log.info("End of epoch {} reached at training step {}.", self._data_epoch_nr, self._step_nr) # fmt: skip + def _next_batches(self) -> list[BatchT] | None: + try: + batches = next(self._data_reader) + except StopIteration: + batches = None + + if batches is not None: + self._read_data = True + + return batches - if not self._read_data: # Means the dataset is empty. - break + self._data_reader.reset() - if self._max_num_data_epochs is not None: - if self._data_epoch_nr >= self._max_num_data_epochs: - break + self._end_of_data_epoch = True - self._data_epoch_nr += 1 + log.info("End of epoch {} reached at training step {}.", self._data_epoch_nr, self._step_nr) # fmt: skip - self._data_reader.reset() + if not self._read_data: # The dataset is empty. + self._end_of_data = True + elif self._max_num_data_epochs is not None: + if self._data_epoch_nr >= self._max_num_data_epochs: + self._end_of_data = True - log.info("End of data reached at training step {}.", self._step_nr) + if self._end_of_data: + log.info("End of data reached.", self._step_nr) - raise StopIteration() + # Repeat the step with the first batch of the next epoch. + self._step_nr -= 1 - def _maybe_no_sync(self, batch_nr: int, num_batches: int) -> ContextManager[None]: + self._progress.update(self._train_task_id, advance=-1) + + self._repeat_step = True + + return None + + def _maybe_no_sync( + self, batch_nr: int, num_batches: int + ) -> AbstractContextManager[None]: if batch_nr < num_batches - 1 and self._dp_gang.size > 1: return self._model.no_sync() # type: ignore[no-any-return] return nullcontext() - def _compute_loss(self, batch: BatchT) -> Tuple[Tensor, int]: + def _compute_loss(self, batch: BatchT) -> tuple[Tensor, int | None]: with self._maybe_autocast(): return self._unit(batch) - def _maybe_autocast(self) -> ContextManager[None]: - if self._dtype == torch.float32: + def _maybe_autocast(self) -> AbstractContextManager[None]: + if self._dtype == torch.float32 or not self._amp: return nullcontext() - if self._model.training and isinstance(self._model, (DDP, FSDP)): - if self._model.mixed_precision is not None: - return nullcontext() - return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype) def _should_publish_metrics(self) -> bool: - after_n_steps = self._publish_metrics_after_n_steps - every_n_steps = self._publish_metrics_every_n_steps - - return self._should_do(after_n_steps, every_n_steps) + return self._should_do( + self._publish_metrics_after_n_steps, + self._publish_metrics_every_n_steps, + self._publish_metrics_after_n_data_epochs, + self._publish_metrics_every_n_data_epochs, + ) def _publish_metrics(self) -> None: log.debug("Syncing metrics.") @@ -752,41 +879,45 @@ def _publish_metrics(self) -> None: self._metric_bag.reset_non_persistent_metrics() - elapsed_time = self._step_time + if self._root_gang.rank == 0: + if values is None: + raise InternalError("`values` is `None`.") - self._step_time = 0.0 + extend_batch_metrics( + values, self._num_effective_batches, self._total_step_time + ) - if self._root_gang.rank != 0: - return + values["lr"] = get_effective_lr(self._lr_scheduler) - assert values is not None + values["data_epoch"] = self._data_epoch_nr - values["lr"] = get_effective_lr(self._lr_scheduler) + values["elapsed_time"] = self._total_step_time - set_throughput_value(values, elapsed_time) + values["wall_time"] = self._wall_watch.get_elapsed_time() - values["elapsed_time"] = elapsed_time + record_metrics(self._metric_recorders, "train", values, self._step_nr) - values["wall_time"] = self._wall_watch.get_elapsed_time() + self._num_effective_batches = 0 - record_metrics(self._metric_recorders, "train", values, self._step_nr) + self._total_step_time = 0.0 def _should_validate(self) -> bool: if not self._valid_units: return False - after_n_steps = self._validate_after_n_steps - every_n_steps = self._validate_every_n_steps - - return self._should_do(after_n_steps, every_n_steps) + return self._should_do( + self._validate_after_n_steps, + self._validate_every_n_steps, + self._validate_after_n_data_epochs, + self._validate_every_n_data_epochs, + ) - @torch.inference_mode() def _validate(self) -> None: log.info("Starting validation after step {}.", self._step_nr) - with summon_fsdp_for_validation(self._model): - self._model.eval() + self._model.eval() + with summon_fsdp_for_validation(self._model): unit_scores = [] for unit, data_reader in zip(self._valid_units, self._valid_data_readers): @@ -799,21 +930,22 @@ def _validate(self) -> None: self._valid_score = self._compute_valid_score(unit_scores) - self._model.train() + self._model.train() log.info("Validation complete.") + @torch.inference_mode() def _validate_unit( self, unit: EvalUnit[BatchT], data_reader: DataReader[BatchT] - ) -> Optional[float]: + ) -> float | None: watch = Stopwatch(start=True, device=self._root_gang.device) - unit.model.eval() - unit.set_step_nr(self._step_nr) valid_task = self._progress.add_task("valid", total=None) + num_effective_batches = 0 + for step_nr in count(start=1): self._progress.update(valid_task, advance=1) @@ -822,33 +954,27 @@ def _validate_unit( try: batches = next(data_reader) except StopIteration: - batches = [] + break for batch in batches: with self._maybe_autocast(): unit(batch) - if self._is_valid_eod(batches): - break + num_effective_batches += 1 self._progress.remove_task(valid_task) data_reader.reset() - time = watch.get_elapsed_time() - - metric_values = self._publish_validation_metrics(unit, time) + metric_values = self._publish_validation_metrics( + unit, num_effective_batches, watch.get_elapsed_time() + ) return self._get_unit_score(metric_values) - def _is_valid_eod(self, batches: List[BatchT]) -> bool: - total_num_batches = all_sum(self._dp_gang, len(batches)) - - return bool(total_num_batches == 0) - def _publish_validation_metrics( - self, unit: EvalUnit[BatchT], elapsed_time: float - ) -> Optional[Dict[str, Any]]: + self, unit: EvalUnit[BatchT], num_batches: int, elapsed_time: float + ) -> dict[str, object] | None: log.debug("Syncing validation metrics.") if self._tp_gang.rank == 0: @@ -861,9 +987,12 @@ def _publish_validation_metrics( if self._root_gang.rank != 0: return None - assert values is not None + if values is None: + raise InternalError("`values` is `None`.") + + extend_batch_metrics(values, num_batches, elapsed_time) - set_throughput_value(values, elapsed_time) + values["data_epoch"] = self._data_epoch_nr values["elapsed_time"] = elapsed_time @@ -878,9 +1007,7 @@ def _publish_validation_metrics( return values - def _get_unit_score( - self, metric_values: Optional[Dict[str, Any]] - ) -> Optional[float]: + def _get_unit_score(self, metric_values: dict[str, object] | None) -> float | None: if metric_values is None: return None @@ -898,14 +1025,14 @@ def _get_unit_score( return float(score) - def _compute_valid_score(self, unit_scores: List[float]) -> Optional[float]: + def _compute_valid_score(self, unit_scores: list[float]) -> float | None: if self._score_metric_name is None: return None if not unit_scores: if self._root_gang.rank == 0: - raise RuntimeError( - "None of the validation units returned a score metric value. Please file a bug report to the recipe author." + raise ContractError( + "None of the validation units returned a score metric value." ) return None @@ -945,9 +1072,12 @@ def _maybe_request_early_stop(self) -> None: return if self._root_gang.rank == 0: - assert self._valid_score is not None + if self._valid_score is None: + raise InternalError("Early stopping, but `_valid_score` is `None`.") - should_stop = self._early_stopper(self._step_nr, self._valid_score) + should_stop = self._early_stopper.should_stop( + self._step_nr, self._valid_score + ) else: should_stop = False @@ -957,13 +1087,12 @@ def _maybe_request_early_stop(self) -> None: log.info("Early stop requested. Training will be terminated after saving checkpoint.") # fmt: skip def _should_checkpoint(self) -> bool: - after_n_steps = self._checkpoint_after_n_steps - every_n_steps = self._checkpoint_every_n_steps - - if every_n_steps is None: - return False - - return self._should_do(after_n_steps, every_n_steps) + return self._should_do( + self._checkpoint_after_n_steps, + self._checkpoint_every_n_steps, + self._checkpoint_after_n_data_epochs, + self._checkpoint_every_n_data_epochs, + ) def _checkpoint(self) -> None: step_nr = self._step_nr @@ -1003,7 +1132,7 @@ def _checkpoint(self) -> None: self._checkpoint_manager.save_consolidated_fsdp_model(self._model) - log.info("Consolidated model saved.") + log.info("Consolidated FSDP model saved.") self._checkpoint_manager.commit_checkpoint() @@ -1014,7 +1143,8 @@ def _checkpoint(self) -> None: nm = self._keep_last_n_models if nm is not None: - assert nc is not None + if nc is None: + raise InternalError("`_keep_last_n_checkpoints` is `None`") self._checkpoint_manager.keep_last_n_checkpoints(nm) self._checkpoint_manager.keep_last_n_checkpoints(nc, preserve_model=True) @@ -1025,22 +1155,52 @@ def _checkpoint(self) -> None: nm = self._keep_best_n_models if nm is not None: - assert nc is not None + if nc is None: + raise InternalError("`_keep_best_n_checkpoints` is `None`") - self._checkpoint_manager.keep_best_n_checkpoints(nm) - self._checkpoint_manager.keep_best_n_checkpoints(nc, preserve_model=True) + self._checkpoint_manager.keep_best_n_checkpoints( + nm, lower_better=self._lower_better + ) + self._checkpoint_manager.keep_best_n_checkpoints( + nc, lower_better=self._lower_better, preserve_model=True + ) elif nc is not None: - self._checkpoint_manager.keep_best_n_checkpoints(nc) + self._checkpoint_manager.keep_best_n_checkpoints( + nc, lower_better=self._lower_better + ) - def _should_do(self, after_n_steps: int, n_steps: int) -> bool: - if self._eod or self._should_stop: - return True + def _should_do( + self, + after_n_steps: int, + every_n_steps: int | None, + after_n_data_epochs: int, + every_n_data_epochs: int | None, + ) -> bool: + should_do_at_step = self._should_do_at_step(after_n_steps, every_n_steps) + + if self._end_of_data or self._should_stop: + if not self._read_data: + return False + + return not should_do_at_step - if self._step_nr < after_n_steps: + if self._end_of_data_epoch and every_n_data_epochs is not None: + if self._data_epoch_nr >= after_n_data_epochs: + if self._data_epoch_nr % every_n_data_epochs == 0: + return not should_do_at_step + + if self._repeat_step: return False + return should_do_at_step + + def _should_do_at_step(self, after_n_steps: int, every_n_steps: int | None) -> bool: if self._max_num_steps is not None: if self._step_nr >= self._max_num_steps: return True - return self._step_nr % n_steps == 0 + if every_n_steps is not None: + if self._step_nr >= after_n_steps: + return self._step_nr % every_n_steps == 0 + + return False diff --git a/src/fairseq2/recipes/utils/argparse.py b/src/fairseq2/recipes/utils/argparse.py index c2292bfc4..c66a0c7e3 100644 --- a/src/fairseq2/recipes/utils/argparse.py +++ b/src/fairseq2/recipes/utils/argparse.py @@ -7,7 +7,6 @@ from __future__ import annotations from argparse import ( - SUPPRESS, ZERO_OR_MORE, Action, ArgumentError, @@ -15,24 +14,22 @@ ArgumentTypeError, Namespace, ) -from typing import Any, Dict, List, Optional, final +from typing import Any, final import torch -import yaml -from yaml.parser import ParserError from fairseq2.typing import DataType +from fairseq2.utils.yaml import YamlError, read_yaml @final class ConfigAction(Action): - """Adds support for reading key-value pairs in format ``=``.""" + """ + Adds support for reading configuration key-value pairs in format ``=``. + """ def __init__( - self, - option_strings: List[str], - dest: str, - help: Optional[str] = None, + self, option_strings: list[str], dest: str, help: str | None = None ) -> None: super().__init__( option_strings, @@ -47,98 +44,100 @@ def __call__( parser: ArgumentParser, namespace: Namespace, values: Any, - option_string: Optional[str] = None, + option_string: str | None = None, ) -> None: - data: Dict[str, Any] = {} + data: dict[str, Any] = {} + + def get_parent_node(path: str) -> tuple[dict[str, Any], str]: + keys = path.split(".") + + node = data + + for key in keys[:-1]: + child_node = node.get(key) + + if not isinstance(child_node, dict): + child_node = {} + + node[key] = child_node + + node = child_node + + return node, keys[-1] for item in values: - key_value = item.split("=", maxsplit=1) - if len(key_value) != 2: - raise ArgumentError(self, f"invalid key-value pair: {item}") + item = item.strip() - key, value = [kv.strip() for kv in key_value] + if item.startswith("del:"): + path = item[4:] - try: - parsed_value = yaml.safe_load(value) - except ParserError: - raise ArgumentError( - self, f"invalid key-value pair: {item} (value must be yaml)" - ) + if "=" in path: + raise ArgumentError(self, f"key should not contain '=': {item}") - fields = key.split(".") + parent_node, key = get_parent_node(path) - if not all(f.isidentifier() for f in fields): - raise ArgumentError( - self, f"invalid key-value pair: {item} (key must be identifier)" - ) + del_keys = parent_node.get("_del_") - tmp = data + if not isinstance(del_keys, list): + del_keys = [] - for field in fields[:-1]: - try: - d = tmp[field] - except KeyError: - d = None + parent_node["_del_"] = del_keys - if not isinstance(d, dict): - d = {} + del_keys.append(key) + else: + path_value = item.split("=", maxsplit=1) + if len(path_value) != 2: + raise ArgumentError(self, f"invalid key-value pair: {item}") - tmp[field] = d + path, value = path_value - tmp = d + try: + parsed_value = read_yaml(value.lstrip()) + except YamlError: + raise ArgumentError( + self, f"invalid key-value pair: {item} (value must be yaml)" + ) - tmp[fields[-1]] = parsed_value + path = path.rstrip() - setattr(namespace, self.dest, data) + if path.startswith("add:"): + path = path[4:] + directive = "_add_" + elif path.startswith("set:"): + path = path[4:] -@final -class BooleanOptionalAction(Action): - """Adds support for reading boolean flags in format ``--, --no-``.""" + directive = "_set_" + else: + directive = "_set_" - def __init__( - self, - option_strings: List[str], - dest: str, - default: Any = None, - help: Optional[str] = None, - ) -> None: - all_option_strings = [] + parent_node, key = get_parent_node(path) - for option_string in option_strings: - all_option_strings.append(option_string) + directive_keys = parent_node.get(directive) - if option_string.startswith("--"): - all_option_strings.append(f"--no-{option_string[2:]}") + if not isinstance(directive_keys, dict): + directive_keys = {} - if help is not None: - if default is not None and default is not SUPPRESS: - help += " (default: %(default)s)" + parent_node[directive] = directive_keys - super().__init__( - all_option_strings, nargs=0, default=default, help=help, dest=dest - ) + directive_keys[key] = parsed_value - def __call__( - self, - parser: ArgumentParser, - namespace: Namespace, - values: Any, - option_string: Optional[str] = None, - ) -> None: - if option_string and option_string in self.option_strings: - setattr(namespace, self.dest, not option_string.startswith("--no-")) + items = getattr(namespace, self.dest, None) + if items is None: + items = [] + + items.append(data) - def format_usage(self) -> str: - return " | ".join(self.option_strings) + setattr(namespace, self.dest, items) def parse_dtype(value: str) -> DataType: - """Parse ``value`` as a ``torch.dtype``.""" if value.startswith("torch."): value = value[6:] - if isinstance(dtype := getattr(torch, value, None), DataType): - return dtype + dtype = getattr(torch, value, None) + + if not isinstance(dtype, DataType): + raise ArgumentTypeError("must be a `torch.dtype` identifier") - raise ArgumentTypeError("must be a `torch.dtype` identifier") + return dtype diff --git a/src/fairseq2/recipes/utils/asset.py b/src/fairseq2/recipes/utils/asset.py index 5b88de04a..c9ba43a28 100644 --- a/src/fairseq2/recipes/utils/asset.py +++ b/src/fairseq2/recipes/utils/asset.py @@ -7,9 +7,7 @@ from __future__ import annotations from pathlib import Path -from typing import Union - -from typing_extensions import TypeAlias +from typing import TypeAlias from fairseq2.assets import ( AssetCard, @@ -19,8 +17,9 @@ default_asset_store, load_metadata_file, ) +from fairseq2.utils.yaml import load_yaml -AssetReference: TypeAlias = Union[str, AssetCard, Path] +AssetReference: TypeAlias = str | AssetCard | Path def retrieve_asset_card(name_or_card: AssetReference) -> AssetCard: @@ -35,7 +34,7 @@ def retrieve_asset_card(name_or_card: AssetReference) -> AssetCard: if isinstance(name_or_card, Path): if name_or_card.is_dir(): raise AssetNotFoundError( - f"{name_or_card}", f"An asset metadata file cannot be found at {name_or_card}." # fmt: skip + name_or_card.name, f"An asset metadata file cannot be found at {name_or_card}." # fmt: skip ) return _card_from_file(name_or_card) @@ -60,7 +59,7 @@ def retrieve_asset_card(name_or_card: AssetReference) -> AssetCard: def _card_from_file(file: Path) -> AssetCard: - all_metadata = load_metadata_file(file) + all_metadata = load_metadata_file(file, load_yaml) if len(all_metadata) != 1: raise AssetMetadataError( @@ -71,12 +70,17 @@ def _card_from_file(file: Path) -> AssetCard: metadata["name"] = name - metadata_provider = InProcAssetMetadataProvider([metadata], name="argument") + metadata_provider = InProcAssetMetadataProvider([metadata]) # Strip the environment tag. name, _ = name.split("@", maxsplit=1) - return default_asset_store.retrieve_card(name, extra_provider=metadata_provider) + default_asset_store.user_metadata_providers.append(metadata_provider) + + try: + return default_asset_store.retrieve_card(name) + finally: + default_asset_store.user_metadata_providers.pop() def asset_as_path(name_or_card: AssetReference) -> Path: diff --git a/src/fairseq2/recipes/utils/environment.py b/src/fairseq2/recipes/utils/environment.py deleted file mode 100644 index 05b4e51ea..000000000 --- a/src/fairseq2/recipes/utils/environment.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import os -import subprocess -from abc import ABC, abstractmethod -from random import Random -from typing import AbstractSet, Dict, Type, final - -from fairseq2.typing import override - - -class EnvironmentSetter(ABC): - """Sets job environment variables.""" - - @abstractmethod - def set_torch_distributed_env(self) -> None: - """Set environment variables required to initialize ``torch.distributed``.""" - - @property - @abstractmethod - def cluster(self) -> str: - """The cluster type that this instance supports.""" - - -@final -class SlurmEnvironmentSetter(EnvironmentSetter): - """Sets job environment variables on a Slurm cluster.""" - - _job_id: int - - def __init__(self) -> None: - try: - job_id = os.environ["SLURM_JOB_ID"] - except KeyError: - raise RuntimeError( - "Slurm not detected. `SLURM_JOB_ID` environment variable cannot be found." - ) from None - - try: - self._job_id = int(job_id) - except ValueError as ex: - raise RuntimeError("Slurm job ID cannot be parsed.") from ex - - @override - def set_torch_distributed_env(self) -> None: - try: - os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"] - os.environ["RANK"] = os.environ["SLURM_PROCID"] - - try: - os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] - except KeyError: - os.environ["LOCAL_WORLD_SIZE"] = "1" - - os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] - - os.environ["MASTER_ADDR"] = self._get_master_addr() - os.environ["MASTER_PORT"] = self._get_master_port() - - os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"] - except KeyError as ex: - raise RuntimeError( - "Slurm job environment variables are not correctly set. If you are within an allocated job (i.e. `salloc`), make sure to run with `srun`. If you want to run without Slurm, use `--cluster none`." - ) from ex - - def _get_master_addr(self) -> str: - nodes = os.environ["SLURM_JOB_NODELIST"] - - result = subprocess.run( - ["scontrol", "show", "hostnames", nodes], capture_output=True, text=True - ) - - if result.returncode == 0: - if node_list := result.stdout.split("\n"): - return node_list[0] - - raise RuntimeError( - "The hostname or IP address of the Slurm node corresponding to rank 0 cannot be retrieved." - ) - - def _get_master_port(self) -> str: - try: - return os.environ["MASTER_PORT"] - except KeyError: - pass - - return str(Random(self._job_id).randint(20_000, 60_000)) - - @property - @override - def cluster(self) -> str: - return "slurm" - - -@final -class _NoneEnvironmentSetter(EnvironmentSetter): - @override - def set_torch_distributed_env(self) -> None: - return - - @property - @override - def cluster(self) -> str: - return "none" - - -class EnvironmentSetterRegistry: - """Holds cluster type to :class:`EnvironmentSetter` mappings.""" - - _types: Dict[str, Type[EnvironmentSetter]] - - def __init__(self) -> None: - self._types = {"slurm": SlurmEnvironmentSetter, "none": _NoneEnvironmentSetter} - - def get(self, cluster: str) -> EnvironmentSetter: - """Return the :class:`EnvironmentSetter` of the specified cluster type.""" - try: - kls = self._types[cluster] - except KeyError: - raise ValueError( - f"`cluster` must be a registered cluster name, but is '{cluster}' instead." - ) from None - - try: - return kls() - except TypeError as ex: - raise RuntimeError(f"`{kls}` has no default constructor.") from ex - - def get_for_inferred_cluster(self) -> EnvironmentSetter: - """Return the :class:`EnvironmentSetter` of the inferred cluster.""" - if "TORCHELASTIC_RUN_ID" in os.environ: # means we are in `torchrun`. - return self.get("none") - - for cluster, kls in self._types.items(): - if cluster == "none": - continue - - try: - return kls() - except RuntimeError: - pass - except TypeError as ex: - raise RuntimeError(f"`{kls}` has no default constructor.") from ex - - return self.get("none") - - def register(self, cluster: str, kls: Type[EnvironmentSetter]) -> None: - """Register a new :class:`EnvironmentSetter`. - - :param cluster: - The cluster type. - :param kls: - The :class:`EnvironmentSetter` subclass. Must have an ``__init__`` - method that takes no arguments. - """ - if cluster in self._types: - raise ValueError( - f"`cluster` must be a unique cluster name, but '{cluster}' has already a registered environment setter." - ) - - self._types[cluster] = kls - - def names(self) -> AbstractSet[str]: - """Return the supported cluster types.""" - return self._types.keys() - - -default_env_setters = EnvironmentSetterRegistry() diff --git a/src/fairseq2/recipes/utils/log.py b/src/fairseq2/recipes/utils/log.py index 3420485b0..be725ae48 100644 --- a/src/fairseq2/recipes/utils/log.py +++ b/src/fairseq2/recipes/utils/log.py @@ -9,83 +9,29 @@ import os import platform import socket -import sys -from contextlib import contextmanager -from logging import Logger -from pathlib import Path -from signal import SIG_DFL, SIGINT, raise_signal, signal -from typing import Iterator, Optional import fairseq2n import psutil import torch from rich.pretty import pretty_repr -from torch.cuda import OutOfMemoryError from torch.nn import Module import fairseq2 from fairseq2.logging import LogWriter from fairseq2.nn.utils.module import get_module_size -from fairseq2.typing import DataClass, Device -from fairseq2.utils.dataclass import dump_dataclass +from fairseq2.typing import Device -@contextmanager -def exception_logger(log: LogWriter) -> Iterator[None]: - """Log exceptions and CUDA OOM errors raised within the context.""" - try: - yield - except OutOfMemoryError: - s = torch.cuda.memory_summary() - - log.exception("CUDA run out of memory. See memory stats and exception details below.\n{}", s) # fmt: skip - - sys.exit(1) - except KeyboardInterrupt: - log.info("Command canceled!") - - signal(SIGINT, SIG_DFL) - - raise_signal(SIGINT) - except Exception: - log.exception("Command has failed. See exception details below.") - - sys.exit(1) - - -def log_config(config: DataClass, log: LogWriter, file: Optional[Path] = None) -> None: - """Log ``config``. - - :param config: - The config to log. - :param log: - The log to write to. - :param file: - The output file to write ``config`` in YAML format. - """ - if file is not None: - with file.open("w") as fp: - dump_dataclass(config, fp) - +def log_config(config: object, log: LogWriter) -> None: log.info("Config:\n{}", pretty_repr(config, max_width=88)) -def log_model_config(config: DataClass, log: LogWriter) -> None: - """Log ``config``. - - :param config: - The model config to log. - :param log: - The log to write to. - """ +def log_model_config(config: object, log: LogWriter) -> None: log.info("Model Config:\n{}", pretty_repr(config, max_width=88)) -def log_environment_info(log: LogWriter, device: Optional[Device] = None) -> None: +def log_environment_info(log: LogWriter, device: Device | None = None) -> None: """Log information about the host system and the installed software.""" - if isinstance(log, Logger): - log = LogWriter(log) - log_system_info(log, device) log_software_info(log, device) @@ -93,12 +39,12 @@ def log_environment_info(log: LogWriter, device: Optional[Device] = None) -> Non log_environment_variables(log) -def log_system_info(log: LogWriter, device: Optional[Device] = None) -> None: +def log_system_info(log: LogWriter, device: Device | None = None) -> None: """Log information about the host system.""" if not log.is_enabled_for_info(): return - def read_dist_name() -> Optional[str]: + def read_dist_name() -> str | None: try: fp = open("/etc/os-release") except OSError: @@ -196,14 +142,14 @@ def read_dist_name() -> Optional[str]: if device.type == "cpu": s = "CPU-only" elif device.type == "cuda": - pr = torch.cuda.get_device_properties(device) + props = torch.cuda.get_device_properties(device) s = ( f"ID: {device} | " - f"Name: {pr.name} | " - f"Memory: {pr.total_memory // (1024 * 1024):,} MiB | " - f"Number of SMs: {pr.multi_processor_count} | " - f"Compute Capability: {pr.major}.{pr.minor}" + f"Name: {props.name} | " + f"Memory: {props.total_memory // (1024 * 1024):,} MiB | " + f"Number of SMs: {props.multi_processor_count} | " + f"Compute Capability: {props.major}.{props.minor}" ) else: s = f"ID: {device}" @@ -211,7 +157,7 @@ def read_dist_name() -> Optional[str]: log.info("Device - {}", s) -def log_software_info(log: LogWriter, device: Optional[Device] = None) -> None: +def log_software_info(log: LogWriter, device: Device | None = None) -> None: """Log information about the installed software.""" if not log.is_enabled_for_info(): return @@ -259,7 +205,7 @@ def log_environment_variables(log: LogWriter) -> None: log.info("Environment Variables - {}", ", ".join(kv)) -def log_model(model: Module, log: LogWriter, *, rank: Optional[int] = None) -> None: +def log_model(model: Module, log: LogWriter, *, rank: int | None = None) -> None: """Log information about ``model``.""" if not log.is_enabled_for_info(): return diff --git a/src/fairseq2/recipes/utils/cli.py b/src/fairseq2/recipes/utils/rich.py similarity index 61% rename from src/fairseq2/recipes/utils/cli.py rename to src/fairseq2/recipes/utils/rich.py index 7430bf5c3..c1696e2e6 100644 --- a/src/fairseq2/recipes/utils/cli.py +++ b/src/fairseq2/recipes/utils/rich.py @@ -6,6 +6,8 @@ from __future__ import annotations +from rich import get_console as get_rich_console +from rich.console import Console from rich.progress import ( BarColumn, Progress, @@ -16,14 +18,47 @@ TimeRemainingColumn, ) from rich.text import Text +from typing_extensions import override -from fairseq2.console import get_error_console from fairseq2.gang import get_rank -from fairseq2.typing import override + +_console: Console | None = None + + +def get_console() -> Console: + global _console + + if _console is None: + _console = get_rich_console() + + return _console + + +def set_console(console: Console) -> None: + global _console + + _console = console + + +_error_console: Console | None = None + + +def get_error_console() -> Console: + global _error_console + + if _error_console is None: + _error_console = Console(stderr=True, highlight=False) + + return _error_console + + +def set_error_console(console: Console) -> None: + global _error_console + + _error_console = console def create_rich_progress() -> Progress: - """Create a :class:`Progress` instance to report job progress.""" console = get_error_console() columns = [ diff --git a/src/fairseq2/recipes/utils/setup.py b/src/fairseq2/recipes/utils/setup.py index 062bd7327..62e8a1050 100644 --- a/src/fairseq2/recipes/utils/setup.py +++ b/src/fairseq2/recipes/utils/setup.py @@ -7,7 +7,7 @@ from __future__ import annotations from datetime import timedelta -from typing import Any, Dict, Literal, Optional, Tuple, Type +from typing import Any, Literal import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -27,7 +27,7 @@ def setup_root_gang( log: LogWriter, *, - timeout: Optional[timedelta] = None, + timeout: timedelta | None = None, monitored: bool = False, ) -> Gang: """Set up the root gang. @@ -59,9 +59,9 @@ def setup_gangs( log: LogWriter, *, tp_size: int = 1, - timeout: Optional[timedelta] = None, + timeout: timedelta | None = None, monitored: bool = False, -) -> Tuple[Gang, Dict[str, Gang]]: +) -> tuple[Gang, dict[str, Gang]]: """Set up the root, data, and tensor parallel gangs. :param log: @@ -86,7 +86,7 @@ def setup_gangs( log.info("Data and tensor parallel gangs initialized.") - return root_gang, gangs + return root_gang, {"dp": gangs.dp, "tp": gangs.tp} def broadcast_model(model: Module, gang: Gang, log: LogWriter) -> None: @@ -192,12 +192,12 @@ def compile_model(model: Module, log: LogWriter, *, dynamic: bool = True) -> Mod ) -def check_model_type(model: Module, kls: Type[Module]) -> None: +def check_model_type(model: Module, kls: type[Module]) -> None: """Check if a potentially DDP or FSDP wrapped `model` is of type `kls`.""" if isinstance(model, (DDP, FSDP)): model = model.module if not isinstance(model, kls): raise ValueError( - f"The specified model must be of type `{kls}`, but is of type `{type(model)}` instead." + f"`model` must be of type `{kls}`, but is of type `{type(model)}` instead." ) diff --git a/src/fairseq2/recipes/utils/sweep.py b/src/fairseq2/recipes/utils/sweep.py deleted file mode 100644 index 98b59d1b4..000000000 --- a/src/fairseq2/recipes/utils/sweep.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import os -import re -from dataclasses import fields -from enum import Enum -from hashlib import sha1 -from typing import Any, Final, Mapping, Optional, Sequence, Set - -from fairseq2.typing import DataClass, DataType, is_dataclass_instance - - -class SweepTagger: - """Generates a sweep tag from the diff of two recipe configurations.""" - - _DEFAULT_ALLOW_SET: Final = { - "batch_shuffle_window", - "betas", - "data_parallelism", - "dataset", - "dtype", - "example_shuffle_window", - "final_lr_ratio", - "final_lr_scale", - "fp16_loss_scale", - "fsdp_reshard_after_forward", - "fsdp_wrap_granularity", - "gradient_accumulation", - "label_smoothing", - "lr", - "lr_stage_ratios", - "max_gradient_norm", - "max_num_elements", - "max_num_steps", - "max_num_tokens", - "max_seq_len", - "model", - "model_arch", - "model_config", - "num_lr_warmup_steps", - "pretrained_model", - "seed", - "split", - "start_lr", - "start_lr_scale", - "tensor_parallel_size", - "tokenizer", - "train_split", - "valid_split", - "weight_decay", - } - - def __init__(self, *, allow_set: Optional[Set[str]] = None) -> None: - """ - :param allow_set: - The configuration field names allowed while generating the sweep tag. - """ - if allow_set is None: - allow_set = self._DEFAULT_ALLOW_SET.copy() - - self._allow_set = allow_set - - def extend_allow_set(self, *extras: str) -> None: - """Extend the allowed configuration field names with ``extras``.""" - self._allow_set.update(extras) - - def __call__(self, preset: str, preset_config: DataClass, config: DataClass) -> str: - """ - :param preset: - The name of the preset recipe. - :param preset_config: - The preset (i.e. ground-truth) recipe configuration. - :param config: - The recipe configuration for which to generate a sweep tag. - """ - if type(config) is not type(preset_config): - raise ValueError( - f"`config` must be of the same type as `preset_config` (`{type(preset_config)}`), but is of type `{type(config)}` instead." - ) - - output = [f"preset_{self._remove_non_word(preset)}"] - - try: - world_size = os.environ["WORLD_SIZE"] - except KeyError: - world_size = "1" - - output.append(f"ws_{world_size}") - - def abbrv(s: str) -> str: - if s.startswith("num_"): - s = f"n_{s[4:]}" - - return s - - def generate(config: DataClass) -> None: - for field in fields(config): - value = getattr(config, field.name) - - if is_dataclass_instance(value): - generate(config) - elif field.name in self._allow_set: - if s := self._to_tag_value(value): - output.append(f"{abbrv(field.name)}_{s}") - - def generate_from_diff(preset_config: DataClass, config: DataClass) -> None: - for field in fields(config): - value = getattr(config, field.name) - - preset_value = getattr(preset_config, field.name) - - if is_dataclass_instance(preset_value): - if type(value) is type(preset_value): - generate_from_diff(preset_value, value) - else: - generate(value) - elif field.name in self._allow_set: - if preset_value == value: - continue - - if s := self._to_tag_value(value): - output.append(f"{abbrv(field.name)}_{s}") - - generate_from_diff(preset_config, config) - - s = ".".join(output) - - # Cap to maximum of 128 characters. - if len(s) > 128: - # Make sure we avoid name conflicts by prepending the hash of the - # whole tag to the truncated one. - s = s[:120] + self._hash(s) - - return s - - @classmethod - def _to_tag_value(cls, value: Any) -> Optional[str]: - s: Optional[str] - - if isinstance(value, str): - s = cls._remove_non_word(value) - - if len(s) < 16: - return s - - return cls._hash(s) - - if isinstance(value, bool): - return "t" if value else "f" - - if isinstance(value, (int, float)): - return f"{value}" - - if isinstance(value, DataType): - return f"{value}"[6:] - - if isinstance(value, Enum): - return value.name - - if isinstance(value, Sequence): - output = [] - - for v in value: - if s := cls._to_tag_value(v): - output.append(s) - - if not output: - return None - - s = "-".join(output) - - return f"b{s}e" - - if isinstance(value, Mapping): - output = [] - - for k, v in value.items(): - ks = cls._to_tag_value(k) - vs = cls._to_tag_value(v) - - if ks and vs: - output.append(f"{ks}_{vs}") - - if not output: - return None - - output.sort() - - s = "-".join(output) - - return f"b{s}e" - - return None - - @staticmethod - def _remove_non_word(s: str) -> str: - return re.sub(r"[^-_\w]", "", s) - - @staticmethod - def _hash(s: str) -> str: - s = sha1(s.encode("utf-8")).hexdigest() - - return s[:8] - - -default_sweep_tagger = SweepTagger() diff --git a/src/fairseq2/recipes/utils/sweep_tagger.py b/src/fairseq2/recipes/utils/sweep_tagger.py new file mode 100644 index 000000000..ad2a4d03e --- /dev/null +++ b/src/fairseq2/recipes/utils/sweep_tagger.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from collections.abc import Hashable, Iterable, Sequence, Set +from enum import Enum +from hashlib import sha1 +from typing import final + +from typing_extensions import override + +from fairseq2.utils.dataclass import EMPTY +from fairseq2.utils.structured import StructureError + + +class SweepTagger(ABC): + """Generates a sweep tag from a recipe configuration.""" + + @abstractmethod + def generate( + self, + world_size: int, + preset: str, + unstructured_config: object, + fmt: str | None = None, + ) -> str: + ... + + +@final +class StandardSweepTagger(SweepTagger): + _allowed_keys: Set[Hashable] + + def __init__(self, allowed_keys: Set[Hashable]) -> None: + """ + :param allowed_keys: The recipe configuration keys allowed to be used in + sweep tags. + """ + self._allowed_keys = allowed_keys + + @override + def generate( + self, + world_size: int, + preset: str, + unstructured_config: object, + fmt: str | None = None, + ) -> str: + if fmt is None: + fmt = "ps_{preset}.ws_{world_size}.{hash}" + else: + fmt = fmt.strip() + if not fmt: + raise SweepFormatError("`fmt` must not be empty.") + + tags = {"preset": preset, "world_size": f"{world_size}"} + + self._collect_tags(unstructured_config, tags, path="") + + tags["hash"] = self._generate_hash(tags) + + return self._safe_format(fmt, tags) + + def _collect_tags(self, obj: object, tags: dict[str, str], path: str) -> None: + if obj is None: + tags[path] = "none" + + return + + if obj is EMPTY: + tags[path] = "empty" + + return + + if isinstance(obj, str): + tag = self._remove_non_word(obj) + + if len(tag) >= 16: + tag = self._generate_tag_hash(tag) + + tags[path] = tag + + return + + if isinstance(obj, bool): + tags[path] = "t" if obj else "f" + + return + + if isinstance(obj, int | float): + tags[path] = f"{obj}" + + return + + if isinstance(obj, list): + for idx, elem in enumerate(obj): + self._collect_tags(elem, tags, path=f"{path}[{idx}]") + + return + + if isinstance(obj, dict): + for key, value in obj.items(): + if key in self._allowed_keys: + self._collect_tags( + value, tags, path=f"{path}.{key}" if path else f"{key}" + ) + + return + + raise StructureError( + "`unstructured_config` must be of a composition of types `bool`, `int`, `float`, `str`, `list`, and `dict`." + ) + + @staticmethod + def _remove_non_word(s: str) -> str: + return re.sub(r"[^-_\w]", "", s) + + @staticmethod + def _generate_tag_hash(s: str) -> str: + algo = sha1(s.encode("utf-8")) + + h = algo.hexdigest() + + return h[:8] + + @staticmethod + def _generate_hash(tags: dict[str, str]) -> str: + algo = sha1() + + for k, v in sorted(tags.items()): + algo.update(k.encode("utf-8")) + algo.update(v.encode("utf-8")) + + h = algo.hexdigest() + + return h[:8] + + @staticmethod + def _safe_format(fmt: str, tags: dict[str, str]) -> str: + class State(Enum): + LITERAL = 0 + PLACEHOLDER = 1 + OPENING_BRACE = 2 + CLOSING_BRACE = 3 + + output = [] + + placeholder: list[str] = [] + + unknown_keys: set[str] = set() + + state = State.LITERAL + + for c in fmt: + match state: + case State.LITERAL: + if c == "{": + state = State.OPENING_BRACE + elif c == "}": + state = State.CLOSING_BRACE + else: + output.append(c) + case State.OPENING_BRACE: + if c == "{": # escape + state = State.LITERAL + + output.append("{") + elif c == "}": + raise SweepFormatError( + "`fmt` must not have any empty placeholders" + ) + else: + state = State.PLACEHOLDER + + placeholder.append(c) + case State.PLACEHOLDER: + if c == "}": + state = State.LITERAL + + key = "".join(placeholder) + + tag: Iterable[str] | None = tags.get(key) + if tag is None: + tag = placeholder + + unknown_keys.add(key) + + output.extend(tag) + + placeholder.clear() + else: + placeholder.append(c) + case State.CLOSING_BRACE: + state = State.LITERAL + + if c == "}": # escape + output.append("}") + else: + output.append(c) + + if state != State.LITERAL: + raise SweepFormatError( + "`fmt` must have matching opening and closing braces." + ) + + if unknown_keys: + keys = list(unknown_keys) + + keys.sort() + + s = ", ".join(keys) + + raise SweepFormatPlaceholderError( + keys, f"`fmt` must contain only placeholders that correspond to the configuration keys, but contains the following unexpected placeholder(s): {s}" # fmt: skip + ) + + return "".join(output) + + +@final +class NoopSweepTagger(SweepTagger): + @override + def generate( + self, + world_size: int, + preset: str, + unstructured_config: object, + fmt: str | None = None, + ) -> str: + return "" + + +class SweepFormatError(ValueError): + pass + + +class SweepFormatPlaceholderError(SweepFormatError): + unknown_keys: Sequence[str] + + def __init__(self, unknown_keys: Sequence[str], message: str) -> None: + super().__init__(message) + + self.unknown_keys = unknown_keys diff --git a/src/fairseq2/recipes/wav2vec2/__init__.py b/src/fairseq2/recipes/wav2vec2/__init__.py index e69de29bb..5a49dbe4c 100644 --- a/src/fairseq2/recipes/wav2vec2/__init__.py +++ b/src/fairseq2/recipes/wav2vec2/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.recipes.cli import Cli, RecipeCommandHandler +from fairseq2.recipes.wav2vec2.eval import ( + load_wav2vec2_evaluator, + wav2vec2_eval_presets, +) +from fairseq2.recipes.wav2vec2.train import ( + load_wav2vec2_trainer, + wav2vec2_train_presets, +) + + +def _setup_wav2vec2_cli(cli: Cli) -> None: + extra_sweep_keys = {"max_audio_len", "min_audio_len", "normalize_audio"} + + group = cli.add_group("wav2vec2", help="wav2vec 2.0 pretraining recipes") + + # Train + train_handler = RecipeCommandHandler( + loader=load_wav2vec2_trainer, + preset_configs=wav2vec2_train_presets, + default_preset="base_960h", + extra_sweep_keys=extra_sweep_keys, + ) + + group.add_command( + name="train", + handler=train_handler, + help="train a wav2vec 2.0 model", + ) + + # Eval + eval_handler = RecipeCommandHandler( + loader=load_wav2vec2_evaluator, + preset_configs=wav2vec2_eval_presets, + default_preset="base_ls960h", + extra_sweep_keys=extra_sweep_keys, + ) + + group.add_command( + name="eval", + handler=eval_handler, + help="evaluate a wav2vec 2.0 model", + ) diff --git a/src/fairseq2/recipes/wav2vec2/asr/__init__.py b/src/fairseq2/recipes/wav2vec2/asr/__init__.py index ce03ae689..0f794ed16 100644 --- a/src/fairseq2/recipes/wav2vec2/asr/__init__.py +++ b/src/fairseq2/recipes/wav2vec2/asr/__init__.py @@ -7,7 +7,6 @@ from __future__ import annotations from fairseq2.recipes.cli import Cli, RecipeCommandHandler -from fairseq2.recipes.utils.sweep import default_sweep_tagger from fairseq2.recipes.wav2vec2.asr.eval import ( load_wav2vec2_asr_evaluator, wav2vec2_asr_eval_presets, @@ -19,12 +18,12 @@ def _setup_wav2vec2_asr_cli(cli: Cli) -> None: - default_sweep_tagger.extend_allow_set( + extra_sweep_keys = { "freeze_encoder_for_n_steps", "max_audio_len", "min_audio_len", "normalize_audio", - ) + } group = cli.add_group("wav2vec2_asr", help="wav2vec 2.0 ASR recipes") @@ -33,6 +32,7 @@ def _setup_wav2vec2_asr_cli(cli: Cli) -> None: loader=load_wav2vec2_asr_trainer, preset_configs=wav2vec2_asr_train_presets, default_preset="base_10h", + extra_sweep_keys=extra_sweep_keys, ) group.add_command( @@ -46,6 +46,7 @@ def _setup_wav2vec2_asr_cli(cli: Cli) -> None: loader=load_wav2vec2_asr_evaluator, preset_configs=wav2vec2_asr_eval_presets, default_preset="base_10h", + extra_sweep_keys=extra_sweep_keys, ) group.add_command( diff --git a/src/fairseq2/recipes/wav2vec2/asr/common.py b/src/fairseq2/recipes/wav2vec2/asr/common.py index 973339f92..a917e2e93 100644 --- a/src/fairseq2/recipes/wav2vec2/asr/common.py +++ b/src/fairseq2/recipes/wav2vec2/asr/common.py @@ -7,59 +7,178 @@ from __future__ import annotations import math +from typing import Any, TextIO, final import torch from torch import Tensor +from torch.nn import Module +from typing_extensions import override +from fairseq2.data.text import TextTokenDecoder, TextTokenizer from fairseq2.gang import Gang from fairseq2.metrics.aggregation import Mean +from fairseq2.metrics.text import WerMetric from fairseq2.models.seq2seq import Seq2SeqBatch -from fairseq2.recipes.common_metrics import TaskMetricBag +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel, Wav2Vec2AsrOutput +from fairseq2.recipes.common_metrics import BaseMetricBag +from fairseq2.recipes.utils.setup import check_model_type -class Wav2Vec2AsrMetricBag(TaskMetricBag): - """Holds the metrics of a wav2vec 2.0 ASR model training or evaluation task.""" +@final +class Wav2Vec2AsrCriterion: + _model: Module + _scorer: Wav2Vec2AsrScorer | None - _ctc_loss: Mean + def __init__(self, model: Module, scorer: Wav2Vec2AsrScorer | None = None) -> None: + check_model_type(model, Wav2Vec2AsrModel) + + self._model = model + + self._scorer = scorer + + def __call__( + self, batch: Seq2SeqBatch, metric_bag: Wav2Vec2AsrMetricBag + ) -> tuple[Tensor, int]: + input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) + + output = self._forward(input_batch) + + loss = output.compute_loss(batch.target_seqs, batch.target_padding_mask) + + metric_bag.update_ctc_loss(batch, loss) + + metric_bag.update_batch_metrics(batch) + + if self._scorer is not None: + self._scorer(batch, output, metric_bag) + + return loss, batch.batch_size + + def _forward(self, batch: SequenceBatch) -> Wav2Vec2AsrOutput: + return self._model(batch) # type: ignore[no-any-return] + + @property + def model(self) -> Module: + return self._model + + +@final +class Wav2Vec2AsrScorer: + _text_decoder: TextTokenDecoder + _pad_idx: int + _blank_label: int + _ref_output_stream: TextIO | None + _hyp_output_stream: TextIO | None + + def __init__( + self, + tokenizer: TextTokenizer, + *, + blank_label: int = 0, + ref_output_stream: TextIO | None = None, + hyp_output_stream: TextIO | None = None, + ) -> None: + """ + :param tokenizer: The tokenizer to encode target text. + :param blank_label: The blank label in logits. + :param ref_output_stream: The output stream to dump references. + :param hyp_output_stream: The output stream to dump hypotheses. + """ + self._text_decoder = tokenizer.create_decoder() + + pad_idx = tokenizer.vocab_info.pad_idx + if pad_idx is None: + raise ValueError( + "``vocab_info` of `tokenizer` must have a PAD symbol defined." + ) + + self._pad_idx = pad_idx + + self._blank_label = blank_label + + self._ref_output_stream = ref_output_stream + self._hyp_output_stream = hyp_output_stream + + def __call__( + self, + batch: Seq2SeqBatch, + output: Wav2Vec2AsrOutput, + metric_bag: Wav2Vec2AsrMetricBag, + ) -> None: + # (N, S), (N, S) + ref_seqs, ref_padding_mask = batch.target_seqs, batch.target_padding_mask + + # (N, S), (N, S) + hyp_seqs, hyp_padding_mask = output.generate_hypotheses( + self._pad_idx, self._blank_label + ) + + refs = [self._text_decoder(s) for s in ref_seqs] + hyps = [self._text_decoder(s) for s in hyp_seqs] + + metric_bag.wer.update( + refs, ref_seqs, ref_padding_mask, hyps, hyp_seqs, hyp_padding_mask + ) + + # Dump references. + if stream := self._ref_output_stream: + for ref in refs: + stream.write(ref) + stream.write("\n") + + stream.flush() + + # Dump hypotheses. + if stream := self._hyp_output_stream: + for hyp in hyps: + stream.write(hyp) + stream.write("\n") + + stream.flush() + + +class Wav2Vec2AsrMetricBag(BaseMetricBag): + ctc_loss: Mean + wer: WerMetric def __init__(self, gang: Gang, train: bool = True) -> None: super().__init__(gang, train=train) d = gang.device - self.register_metric("_ctc_loss", Mean(device=d), persistent=False) + self.register_metric("ctc_loss", Mean(device=d), persistent=False) + + self.register_metric("wer", WerMetric(device=d), persistent=False) @torch.inference_mode() def update_ctc_loss(self, batch: Seq2SeqBatch, loss: Tensor) -> None: - """Update the CTC loss metric. - - :param batch: - The batch processed by the model. - :param ctc_loss: - The loss of ``batch``. - """ - normalized_loss = loss / batch.batch_size / math.log(2) + n = batch.batch_size - self._ctc_loss.update(normalized_loss, weight=batch.batch_size) + self.ctc_loss.update(loss.detach() / n / math.log(2), weight=n) @torch.inference_mode() def update_batch_metrics(self, batch: Seq2SeqBatch) -> None: - """Update the batch metrics. - - :param batch: - The batch processed by the model. - """ num_examples = batch.batch_size - num_elements = batch.num_source_elements() - self._num_batches.update(1) + num_elements = batch.num_source_elements() - self._num_examples.update(num_examples) - self._num_elements.update(num_elements) + self.num_examples.update(num_examples) + self.num_elements.update(num_elements) if self._train: - assert self._total_num_examples is not None - assert self._total_num_elements is not None + assert self.total_num_examples is not None + assert self.total_num_elements is not None + + self.total_num_examples.update(num_examples) + self.total_num_elements.update(num_elements) + + @override + def process_metric_values(self, values: dict[str, Any]) -> None: + super().process_metric_values(values) + + uer, wer = values.pop("wer") - self._total_num_examples.update(num_examples) - self._total_num_elements.update(num_elements) + if uer >= 0.0 and wer >= 0.0: + values["uer"] = uer + values["wer"] = wer diff --git a/src/fairseq2/recipes/wav2vec2/asr/eval.py b/src/fairseq2/recipes/wav2vec2/asr/eval.py index d794913ed..ddb2db985 100644 --- a/src/fairseq2/recipes/wav2vec2/asr/eval.py +++ b/src/fairseq2/recipes/wav2vec2/asr/eval.py @@ -8,25 +8,22 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional, TextIO, final +from typing import final import torch -from torch.nn import Module +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider from fairseq2.config_registry import ConfigRegistry -from fairseq2.data.text import TextTokenDecoder, TextTokenizer, load_text_tokenizer +from fairseq2.data.text import load_text_tokenizer from fairseq2.datasets import LengthBatching -from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset +from fairseq2.datasets.asr import AsrReadOptions, GenericAsrDataset, load_asr_dataset from fairseq2.gang import Gang from fairseq2.logging import get_log_writer -from fairseq2.metrics.text import WerMetric from fairseq2.models import load_model from fairseq2.models.seq2seq import Seq2SeqBatch -from fairseq2.models.sequence import SequenceBatch from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel -from fairseq2.models.wav2vec2.asr.model import Wav2Vec2AsrOutput from fairseq2.nn.utils.module import remove_parametrizations from fairseq2.recipes.evaluator import AbstractEvalUnit, Evaluator from fairseq2.recipes.utils.asset import ( @@ -35,19 +32,19 @@ retrieve_asset_card, ) from fairseq2.recipes.utils.log import log_model -from fairseq2.recipes.utils.setup import ( - broadcast_model, - check_model_type, - setup_root_gang, +from fairseq2.recipes.utils.setup import broadcast_model, setup_root_gang +from fairseq2.recipes.wav2vec2.asr.common import ( + Wav2Vec2AsrCriterion, + Wav2Vec2AsrMetricBag, + Wav2Vec2AsrScorer, ) -from fairseq2.recipes.wav2vec2.asr.common import Wav2Vec2AsrMetricBag -from fairseq2.typing import META, DataType, override +from fairseq2.typing import META, DataType from fairseq2.utils.profiler import Stopwatch log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class Wav2Vec2AsrEvalConfig: """Holds the configuration of a wav2vec 2.0 ASR model evaluation task.""" @@ -77,12 +74,15 @@ class Wav2Vec2AsrEvalConfig: model: AssetReference = "wav2vec2_asr_base_10h" """The name or path to the asset card of the wav2vec 2.0 ASR model to evaluate.""" - checkpoint_dir: Optional[Path] = None + checkpoint_dir: Path | None = None """The checkpoint directory containing models saved by :class:`FileCheckpointManager`.""" dtype: DataType = torch.float16 """The data type of the model.""" + amp: bool = False + """If ``True``, runs evaluation with ``torch.amp``.""" + # Misc seed: int = 2 """The random number generator seed to use.""" @@ -98,6 +98,7 @@ def _base_10h() -> Wav2Vec2AsrEvalConfig: return Wav2Vec2AsrEvalConfig() +@torch.inference_mode() def load_wav2vec2_asr_evaluator( config: Wav2Vec2AsrEvalConfig, output_dir: Path ) -> Evaluator[Seq2SeqBatch]: @@ -106,15 +107,11 @@ def load_wav2vec2_asr_evaluator( if config.checkpoint_dir is not None: default_asset_store.metadata_providers.append( - CheckpointModelMetadataProvider( - config.checkpoint_dir, lower_score_better=True - ) + CheckpointModelMetadataProvider(config.checkpoint_dir) ) gang = setup_root_gang(log) - seed = config.seed - model_card = retrieve_asset_card(config.model) # Load the tokenizer. @@ -156,7 +153,10 @@ def load_wav2vec2_asr_evaluator( "The model cannot be initialized. See nested exception for details." ) from ex - check_model_type(model, Wav2Vec2AsrModel) + if not isinstance(model, Wav2Vec2AsrModel): + raise ValueError( + f"The model must be of type `{Wav2Vec2AsrModel}`, but is of type `{type(model)}` instead." + ) gang.barrier() @@ -170,7 +170,7 @@ def load_wav2vec2_asr_evaluator( log_model(model, log) - # Initialize the evaluation unit. + # Initialize the criterion. ref_output_file = output_dir.joinpath(f"transcriptions/rank_{gang.rank}.ref.txt") hyp_output_file = output_dir.joinpath(f"transcriptions/rank_{gang.rank}.hyp.txt") @@ -195,28 +195,40 @@ def load_wav2vec2_asr_evaluator( f"The output file '{hyp_output_file}' cannot be created. See nested exception for details." ) from ex - unit = Wav2Vec2AsrEvalUnit( - model, - gang, - tokenizer, - ref_output_stream=ref_output_fp, - hyp_output_stream=hyp_output_fp, + scorer = Wav2Vec2AsrScorer( + tokenizer, ref_output_stream=ref_output_fp, hyp_output_stream=hyp_output_fp ) - data_reader = dataset.create_reader( - config.split, - tokenizer, - gang, - batching=LengthBatching(config.max_num_elements), + criterion = Wav2Vec2AsrCriterion(model, scorer) + + # Initialize the unit. + unit = Wav2Vec2AsrEvalUnit(criterion, gang) + + seed = config.seed + + options = AsrReadOptions( dtype=config.dtype, - min_audio_len=config.min_audio_len, - max_audio_len=config.max_audio_len, normalize_audio=config.normalize_audio, - sync_batches=False, + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed, ) + try: + data_reader = dataset.create_reader( + config.split, + tokenizer, + gang, + config.min_audio_len, + config.max_audio_len, + LengthBatching(config.max_num_elements), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + seed += 1 # Initialize the evaluator. @@ -224,6 +236,8 @@ def load_wav2vec2_asr_evaluator( units=[unit], data_readers=[data_reader], root_gang=gang, + dtype=config.dtype, + amp=config.amp, tb_dir=output_dir.joinpath("tb"), metrics_dir=output_dir.joinpath("metrics"), seed=seed, @@ -233,134 +247,21 @@ def load_wav2vec2_asr_evaluator( @final class Wav2Vec2AsrEvalUnit(AbstractEvalUnit[Seq2SeqBatch]): - """Represents a wav2vec 2.0 ASR model evaluation unit.""" - - _text_decoder: TextTokenDecoder - _pad_idx: int - _blank_label: int - _ref_output_stream: Optional[TextIO] - _hyp_output_stream: Optional[TextIO] - _metric_bag: Wav2Vec2AsrEvalMetricBag - - def __init__( - self, - model: Module, - gang: Gang, - tokenizer: TextTokenizer, - *, - blank_label: int = 0, - ref_output_stream: Optional[TextIO] = None, - hyp_output_stream: Optional[TextIO] = None, - ) -> None: - """ - :param model: - The wav2vec 2.0 ASR model. Might be wrapped with DDP or FSDP. - :param gang: - The gang for distributed evaluation. - :param tokenizer: - The tokenizer to encode target text. - :param blank_label: - The blank label in logits. - :param ref_output_stream: - The output stream to dump references. - :param hyp_output_stream: - The output stream to dump hypotheses. - """ - super().__init__(model) - - check_model_type(model, Wav2Vec2AsrModel) - - self._text_decoder = tokenizer.create_decoder() - - pad_idx = tokenizer.vocab_info.pad_idx - if pad_idx is None: - raise ValueError( - "``vocab_info` of `tokenizer` must have a PAD symbol defined." - ) - - self._pad_idx = pad_idx - - self._blank_label = blank_label - - self._ref_output_stream = ref_output_stream - self._hyp_output_stream = hyp_output_stream - - self._metric_bag = Wav2Vec2AsrEvalMetricBag(gang) - - @override - def __call__(self, batch: Seq2SeqBatch) -> None: - input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) - - output = self._forward(input_batch) - - loss = output.compute_loss(batch.target_seqs, batch.target_padding_mask) + _criterion: Wav2Vec2AsrCriterion + _metric_bag: Wav2Vec2AsrMetricBag - self._metric_bag.update_ctc_loss(batch, loss.detach()) + def __init__(self, criterion: Wav2Vec2AsrCriterion, gang: Gang) -> None: + super().__init__(criterion.model) - self._metric_bag.update_batch_metrics(batch) + self._criterion = criterion - self._compute_wer(batch, output) + self._metric_bag = Wav2Vec2AsrMetricBag(gang, train=False) - def _compute_wer(self, batch: Seq2SeqBatch, output: Wav2Vec2AsrOutput) -> None: - # (N, S), (N, S) - ref_seqs, ref_padding_mask = batch.target_seqs, batch.target_padding_mask - - # (N, S), (N, S) - hyp_seqs, hyp_padding_mask = output.generate_hypotheses( - self._pad_idx, self._blank_label - ) - - refs = [self._text_decoder(s) for s in ref_seqs] - hyps = [self._text_decoder(s) for s in hyp_seqs] - - self._metric_bag.wer.update( - refs, ref_seqs, ref_padding_mask, hyps, hyp_seqs, hyp_padding_mask - ) - - # Dump references. - if stream := self._ref_output_stream: - for ref in refs: - stream.write(ref) - stream.write("\n") - - stream.flush() - - # Dump hypotheses. - if stream := self._hyp_output_stream: - for hyp in hyps: - stream.write(hyp) - stream.write("\n") - - stream.flush() - - def _forward(self, batch: SequenceBatch) -> Wav2Vec2AsrOutput: - return self._model(batch) # type: ignore[no-any-return] + @override + def __call__(self, batch: Seq2SeqBatch) -> None: + self._criterion(batch, self._metric_bag) @property @override - def metric_bag(self) -> Wav2Vec2AsrEvalMetricBag: + def metric_bag(self) -> Wav2Vec2AsrMetricBag: return self._metric_bag - - -class Wav2Vec2AsrEvalMetricBag(Wav2Vec2AsrMetricBag): - """Holds the metrics of a wav2vec 2.0 ASR model evaluation task.""" - - wer: WerMetric - - def __init__(self, gang: Gang) -> None: - """ - :param gang: - The gang over which to sync metrics. - """ - super().__init__(gang, train=False) - - self.register_metric("wer", WerMetric(device=gang.device), persistent=False) - - @override - def process_metric_values(self, values: Dict[str, Any]) -> None: - super().process_metric_values(values) - - uer, wer = values.pop("wer") - - values["uer"] = uer - values["wer"] = wer diff --git a/src/fairseq2/recipes/wav2vec2/asr/train.py b/src/fairseq2/recipes/wav2vec2/asr/train.py index 79c5a0139..9dae58d9d 100644 --- a/src/fairseq2/recipes/wav2vec2/asr/train.py +++ b/src/fairseq2/recipes/wav2vec2/asr/train.py @@ -6,30 +6,29 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Literal, Optional, Tuple, final +from typing import Any, Literal, final import torch from torch import Tensor -from torch.nn import Module +from typing_extensions import override from fairseq2.assets import AssetNotFoundError, default_asset_store from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager from fairseq2.config_registry import ConfigRegistry from fairseq2.data.text import load_text_tokenizer from fairseq2.datasets import LengthBatching -from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset +from fairseq2.datasets.asr import AsrReadOptions, GenericAsrDataset, load_asr_dataset from fairseq2.gang import Gang from fairseq2.logging import get_log_writer from fairseq2.models import create_model from fairseq2.models.seq2seq import Seq2SeqBatch -from fairseq2.models.sequence import SequenceBatch from fairseq2.models.wav2vec2 import load_wav2vec2_model -from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel, Wav2Vec2AsrOutput +from fairseq2.models.wav2vec2.asr import Wav2Vec2AsrModel from fairseq2.nn.utils.module import freeze_parameters, share_parameters, to_device -from fairseq2.optim import AdamW -from fairseq2.optim.lr_scheduler import TriStageLR +from fairseq2.optim import AdamWConfig, create_optimizer +from fairseq2.optim.lr_scheduler import TriStageLRConfig, create_lr_scheduler from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer from fairseq2.recipes.utils.asset import ( AssetReference, @@ -38,20 +37,24 @@ ) from fairseq2.recipes.utils.log import log_model, log_model_config from fairseq2.recipes.utils.setup import ( - check_model_type, compile_model, setup_root_gang, to_data_parallel, ) -from fairseq2.recipes.wav2vec2.asr.common import Wav2Vec2AsrMetricBag +from fairseq2.recipes.wav2vec2.asr.common import ( + Wav2Vec2AsrCriterion, + Wav2Vec2AsrMetricBag, + Wav2Vec2AsrScorer, +) from fairseq2.recipes.wav2vec2.asr.eval import Wav2Vec2AsrEvalUnit -from fairseq2.typing import META, DataType, override +from fairseq2.typing import CPU, META, DataType from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed log = get_log_writer(__name__) -@dataclass +@dataclass(kw_only=True) class Wav2Vec2AsrTrainConfig: """Holds the configuration of a wav2vec 2.0 ASR model training task. @@ -100,11 +103,11 @@ class Wav2Vec2AsrTrainConfig: model_family: str = "wav2vec2_asr" """The family of the model.""" - model_arch: Optional[str] = "base_10h" + model_arch: str | None = "base_10h" """The architecture of the model.""" model_config: Any = None - """The model configuration.""" + """The configuration of the model.""" dtype: DataType = torch.float16 """The data type of the model.""" @@ -119,25 +122,28 @@ class Wav2Vec2AsrTrainConfig: """If ``True``, applies ``torch.compile()`` to the encoder. (experimental)""" # Optimizer, LR, and Loss - lr: float = 5e-05 - """The initial (post-warm-up) learning rate.""" - - betas: Tuple[float, float] = (0.9, 0.98) - """The coefficients of AdamW.""" + optimizer: str = "adamw" + """The optimizer.""" - lr_stage_ratios: Tuple[float, float, float] = (0.1, 0.4, 0.5) - """The ratios of tri-stage learning rate scheduler.""" + optimizer_config: Any = field( + default_factory=lambda: AdamWConfig(lr=5e-05, betas=(0.9, 0.98)) + ) + """The configuration of the optimizer.""" - start_lr_scale: float = 0.01 - """The scale of the initial warm-up learning rate.""" + lr_scheduler: str = "tri-stage" + """The learning rate scheduler.""" - final_lr_scale: float = 0.05 - """The scale of the final learning rate.""" + lr_scheduler_config: Any = field( + default_factory=lambda: TriStageLRConfig( + stage_ratio=(0.1, 0.4, 0.5), start_lr_scale=0.01, final_lr_scale=0.05 + ) + ) + """The configuration of the learning rate scheduler.""" - max_gradient_norm: Optional[float] = None + max_gradient_norm: float | None = None """The maximum gradient norm. If ``None``, no clipping will be applied.""" - fp16_loss_scale: Tuple[float, float] = (128.0, 0.0001) + fp16_loss_scale: tuple[float, float] = (128.0, 0.0001) """The initial and minimum loss scale for fp16 training.""" gradient_accumulation: int = 4 @@ -147,7 +153,7 @@ class Wav2Vec2AsrTrainConfig: max_num_steps: int = 20_000 """The maximum number of steps to train for.""" - max_num_data_epochs: Optional[int] = None + max_num_data_epochs: int | None = None """The maximum number of data epochs to train for.""" freeze_encoder_for_n_steps: int = 10_000 @@ -165,22 +171,22 @@ class Wav2Vec2AsrTrainConfig: checkpoint_every_n_steps: int = 1000 """The step interval at which to checkpoint.""" - keep_best_n_checkpoints: Optional[int] = None + keep_best_n_checkpoints: int | None = None """The number of checkpoints to keep based on their validation score. If ``None``, none will be deleted.""" publish_metrics_every_n_steps: int = 200 """The step interval at which to publish metrics.""" - # Checkpointing - resume_checkpoint_dir: Optional[Path] = None + # Checkpoint + resume_checkpoint_dir: Path | None = None """If not ``None``, adds the specified path to the default asset store.""" # Misc seed: int = 2 """The random number generator seed to use.""" - profile: Optional[Tuple[int, int]] = None + profile: tuple[int, int] | None = None """The number of steps that the PyTorch profiler should skip and then record.""" monitored_gang: bool = False @@ -204,15 +210,47 @@ def _base_10h() -> Wav2Vec2AsrTrainConfig: def _base_100h() -> Wav2Vec2AsrTrainConfig: config = _base_10h() + assert isinstance(config.optimizer_config, AdamWConfig) + config.dataset = "librispeech_asr_100h" config.model_arch = "base_100h" - config.lr = 0.00003 + config.optimizer_config.lr = 0.00003 config.max_num_steps = 50_000 config.freeze_encoder_for_n_steps = 0 return config +@wav2vec2_asr_train_preset("large_10h") +def _large_10h() -> Wav2Vec2AsrTrainConfig: + config = _base_10h() + + assert isinstance(config.optimizer_config, AdamWConfig) + + config.model_arch = "large_10h" + config.pretrained_model = "wav2vec2_large" + config.max_audio_len = 640_000 + config.max_num_elements = 1_280_000 + config.optimizer_config.lr = 0.0001 + config.gradient_accumulation = 5 + + return config + + +@wav2vec2_asr_train_preset("large_100h") +def _large_100h() -> Wav2Vec2AsrTrainConfig: + config = _large_10h() + + assert isinstance(config.optimizer_config, AdamWConfig) + + config.dataset = "librispeech_asr_100h" + config.model_arch = "large_100h" + config.optimizer_config.lr = 0.00003 + config.max_num_steps = 50_000 + + return config + + def load_wav2vec2_asr_trainer( config: Wav2Vec2AsrTrainConfig, output_dir: Path ) -> Trainer[Seq2SeqBatch]: @@ -221,19 +259,13 @@ def load_wav2vec2_asr_trainer( gang = setup_root_gang(log, monitored=config.monitored_gang) - checkpoint_manager = FileCheckpointManager( - output_dir.joinpath("checkpoints"), gang, lower_score_better=True - ) + checkpoint_manager = FileCheckpointManager(output_dir.joinpath("checkpoints"), gang) if config.resume_checkpoint_dir is not None: default_asset_store.metadata_providers.append( - CheckpointModelMetadataProvider( - config.resume_checkpoint_dir, lower_score_better=True - ) + CheckpointModelMetadataProvider(config.resume_checkpoint_dir) ) - seed = config.seed - tokenizer_card = retrieve_asset_card(config.tokenizer) # Load the tokenizer. @@ -260,7 +292,13 @@ def load_wav2vec2_asr_trainer( dataset = GenericAsrDataset.from_path(dataset_path) + seed = config.seed + # Initialize the model + manual_seed(seed, CPU, gang.device) + + seed += 1 + try: model, model_config = create_model( config.model_family, @@ -274,13 +312,15 @@ def load_wav2vec2_asr_trainer( "The model cannot be initialized. See nested exception for details." ) from ex - check_model_type(model, Wav2Vec2AsrModel) + if not isinstance(model, Wav2Vec2AsrModel): + raise ValueError( + f"The model must be of type `{Wav2Vec2AsrModel}`, but is of type `{type(model)}` instead." + ) log_model_config(model_config, log) - checkpoint_manager.save_model_metadata( - family=model.family, config=model_config, tokenizer_name=tokenizer_card.name - ) + checkpoint_manager.save_model_metadata(family=model.family, config=model_config) + checkpoint_manager.save_tokenizer_metadata(tokenizer_card.name) has_checkpoint = checkpoint_manager.has_checkpoint() @@ -309,9 +349,7 @@ def load_wav2vec2_asr_trainer( log.info("Pretrained model loaded on rank 0.") if gang.rank == 0: - to_device(model, gang.device, seed=seed) - - seed += 1 + to_device(model, gang.device) gang.barrier() @@ -328,7 +366,6 @@ def load_wav2vec2_asr_trainer( config.data_parallelism, log, ddp_find_unused_parameters=config.freeze_encoder_for_n_steps > 0, - fsdp_skip_init=True, fsdp_broadcast_state=not has_checkpoint, fsdp_mixed_precision_dtype=config.dtype, fsdp_fp32_reduce=True, @@ -340,19 +377,16 @@ def load_wav2vec2_asr_trainer( log_model(dp_model, log, rank=gang.rank) - # Initialize the train unit and the optimizer. + # Initialize the train criterion. + criterion = Wav2Vec2AsrCriterion(dp_model) + + # Initialize the train unit. unit = Wav2Vec2AsrTrainUnit( - dp_model, gang, freeze_encoder_for_n_steps=config.freeze_encoder_for_n_steps + criterion, gang, freeze_encoder_for_n_steps=config.freeze_encoder_for_n_steps ) - data_reader = dataset.create_reader( - config.train_split, - tokenizer, - gang, - batching=LengthBatching(config.max_num_elements), + options = AsrReadOptions( dtype=config.dtype, - min_audio_len=config.min_audio_len, - max_audio_len=config.max_audio_len, normalize_audio=config.normalize_audio, example_shuffle_window=config.example_shuffle_window, batch_shuffle_window=config.batch_shuffle_window, @@ -361,36 +395,82 @@ def load_wav2vec2_asr_trainer( seed=seed, ) + try: + data_reader = dataset.create_reader( + config.train_split, + tokenizer, + gang, + config.min_audio_len, + config.max_audio_len, + LengthBatching(config.max_num_elements), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + seed += 1 - optimizer = AdamW(dp_model.parameters(), lr=config.lr, betas=config.betas) + # Initialize the optimizer. + try: + optimizer = create_optimizer( + config.optimizer, dp_model, config.optimizer_config + ) + except ValueError as ex: + raise ValueError( + "The optimizer cannot be created. See nested exception for details." + ) from ex - lr_scheduler = TriStageLR( - optimizer, - config.max_num_steps, - config.lr_stage_ratios, - start_lr_scale=config.start_lr_scale, - final_lr_scale=config.final_lr_scale, - ) + # Initialize the learning rate scheduler. + try: + lr_scheduler = create_lr_scheduler( + config.lr_scheduler, + optimizer, + config.lr_scheduler_config, + max_num_steps=config.max_num_steps, + ) + except ValueError as ex: + raise ValueError( + "The learning rate scheduler cannot be created. See nested exception for details." + ) from ex + + # Initialize the validation criterion. + scorer = Wav2Vec2AsrScorer(tokenizer) + + valid_criterion = Wav2Vec2AsrCriterion(dp_model, scorer) # Initialize the validation unit. - valid_unit = Wav2Vec2AsrEvalUnit(dp_model, gang, tokenizer) + valid_unit = Wav2Vec2AsrEvalUnit(valid_criterion, gang) - valid_data_reader = dataset.create_reader( - config.valid_split, - tokenizer, - gang, - batching=LengthBatching(config.max_num_elements), + options = AsrReadOptions( dtype=config.dtype, - min_audio_len=config.min_audio_len, - max_audio_len=config.max_audio_len, normalize_audio=config.normalize_audio, + sync_mode="until_last", num_prefetch=config.num_prefetch, seed=seed, ) + try: + valid_data_reader = dataset.create_reader( + config.valid_split, + tokenizer, + gang, + config.min_audio_len, + config.max_audio_len, + LengthBatching(config.max_num_elements), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader for the valid split cannot be initialized. See nested exception for details." + ) from ex + seed += 1 + # TODO: Fix once we support static mixed precision on one device. + amp = gang.size == 1 or config.data_parallelism != "fsdp" + # Initialize the trainer. return Trainer[Seq2SeqBatch]( unit=unit, @@ -401,6 +481,7 @@ def load_wav2vec2_asr_trainer( lr_scheduler=lr_scheduler, fp16_loss_scale=config.fp16_loss_scale, max_gradient_norm=config.max_gradient_norm, + amp=amp, max_num_steps=config.max_num_steps, max_num_data_epochs=config.max_num_data_epochs, score_metric_name="wer", @@ -425,50 +506,32 @@ def load_wav2vec2_asr_trainer( @final class Wav2Vec2AsrTrainUnit(AbstractTrainUnit[Seq2SeqBatch]): - """Represents a wav2vec 2.0 ASR model training unit.""" - + _criterion: Wav2Vec2AsrCriterion _freeze_encoder_for_n_steps: int _metric_bag: Wav2Vec2AsrMetricBag def __init__( self, - model: Module, + criterion: Wav2Vec2AsrCriterion, gang: Gang, *, freeze_encoder_for_n_steps: int = 0, ) -> None: """ - :param model: - The wav2vec 2.0 ASR model. Might be wrapped with DDP or FSDP. - :param gang: - The gang for distributed training. - :param freeze_encoder_for_n_steps: - The encoder will be frozen for this number of steps. + :param freeze_encoder_for_n_steps: The encoder will be frozen for this + number of steps. """ - super().__init__(model) + super().__init__(criterion.model) - check_model_type(model, Wav2Vec2AsrModel) + self._criterion = criterion self._freeze_encoder_for_n_steps = freeze_encoder_for_n_steps self._metric_bag = Wav2Vec2AsrMetricBag(gang) @override - def __call__(self, batch: Seq2SeqBatch) -> Tuple[Tensor, int]: - input_batch = SequenceBatch(batch.source_seqs, batch.source_padding_mask) - - output = self._forward(input_batch) - - loss = output.compute_loss(batch.target_seqs, batch.target_padding_mask) - - self._metric_bag.update_ctc_loss(batch, loss.detach()) - - self._metric_bag.update_batch_metrics(batch) - - return loss, batch.batch_size - - def _forward(self, batch: SequenceBatch) -> Wav2Vec2AsrOutput: - return self._model(batch) # type: ignore[no-any-return] + def __call__(self, batch: Seq2SeqBatch) -> tuple[Tensor, int]: + return self._criterion(batch, self._metric_bag) @override def set_step_nr(self, step_nr: int) -> None: diff --git a/src/fairseq2/recipes/wav2vec2/common.py b/src/fairseq2/recipes/wav2vec2/common.py new file mode 100644 index 000000000..b6bfa203b --- /dev/null +++ b/src/fairseq2/recipes/wav2vec2/common.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +from typing import final + +import torch +from torch import Tensor +from torch.nn import Module +from torcheval.metrics import MulticlassAccuracy + +from fairseq2.gang import Gang +from fairseq2.metrics.aggregation import Mean +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.wav2vec2 import Wav2Vec2Loss, Wav2Vec2Model, Wav2Vec2Output +from fairseq2.models.wav2vec2.vector_quantizer import ( + GumbelVectorQuantizerOutput, + VectorQuantizerOutput, +) +from fairseq2.recipes.common_metrics import BaseMetricBag +from fairseq2.recipes.utils.setup import check_model_type + + +@final +class Wav2Vec2Criterion: + _model: Module + _diversity_loss_weight: float + _feature_penalty_weight: float + + def __init__( + self, model: Module, diversity_loss_weight: float, feature_penalty_weight: float + ) -> None: + check_model_type(model, Wav2Vec2Model) + + self._model = model + + self._diversity_loss_weight = diversity_loss_weight + self._feature_penalty_weight = feature_penalty_weight + + def __call__( + self, batch: SequenceBatch, metric_bag: Wav2Vec2MetricBag + ) -> tuple[Tensor, int]: + output = self._forward(batch) + + loss = output.compute_loss( + self._diversity_loss_weight, self._feature_penalty_weight + ) + + batch_size, seq_len = output.logits.shape[:2] + + num_targets = batch_size * seq_len + + metric_bag.update_losses(loss, num_targets) + + metric_bag.update_accuracy(output) + + metric_bag.update_quantizer_metrics(output.quantizer_output) + + metric_bag.update_batch_metrics(batch) + + return loss.total, num_targets + + def _forward(self, batch: SequenceBatch) -> Wav2Vec2Output: + return self._model(batch) # type: ignore[no-any-return] + + @property + def model(self) -> Module: + return self._model + + +class Wav2Vec2MetricBag(BaseMetricBag): + loss: Mean + contrastive_loss: Mean + diversity_loss: Mean + feature_penalty: Mean + accuracy: MulticlassAccuracy + code_perplexity: Mean + prob_perplexity: Mean + temperature: Mean + + def __init__(self, gang: Gang, train: bool = True) -> None: + super().__init__(gang, train=train) + + d = gang.device + + self.register_metric("loss", Mean(device=d), persistent=False) + + self.register_metric("contrastive_loss", Mean(device=d), persistent=False) + + self.register_metric("diversity_loss", Mean(device=d), persistent=False) + + self.register_metric("feature_penalty", Mean(device=d), persistent=False) + + self.register_metric("accuracy", MulticlassAccuracy(device=d), persistent=False) + + self.register_metric("code_perplexity", Mean(device=d), persistent=False) + + self.register_metric("prob_perplexity", Mean(device=d), persistent=False) + + self.register_metric("temperature", Mean(device=d), persistent=False) + + @torch.inference_mode() + def update_losses(self, loss: Wav2Vec2Loss, num_targets: int) -> None: + n = num_targets + + d = num_targets * math.log(2) + + self.loss.update(loss.total.detach() / d, weight=n) + + self.contrastive_loss.update(loss.contrastive.detach() / d, weight=n) + + self.diversity_loss.update(loss.diversity.detach() / d, weight=n) + + self.feature_penalty.update(loss.feature_penalty.detach() / d, weight=n) + + @torch.inference_mode() + def update_accuracy(self, output: Wav2Vec2Output) -> None: + # (N x S) + predictions = output.logits.argmax(-1).view(-1) + + # wav2vec2 treats logit at index 0 as the target. + targets = torch.zeros_like(predictions) + + self.accuracy.update(predictions, targets) + + @torch.inference_mode() + def update_quantizer_metrics(self, output: VectorQuantizerOutput) -> None: + if not isinstance(output, GumbelVectorQuantizerOutput): + return + + self.code_perplexity.update(output.code_perplexity) + self.prob_perplexity.update(output.prob_perplexity) + + self.temperature.update(output.temperature) + + @torch.inference_mode() + def update_batch_metrics(self, batch: SequenceBatch) -> None: + """Update the batch metrics.""" + num_examples = batch.batch_size + + num_elements = batch.num_elements() + + self.num_examples.update(num_examples) + self.num_elements.update(num_elements) + + if self._train: + assert self.total_num_examples is not None + assert self.total_num_elements is not None + + self.total_num_examples.update(num_examples) + self.total_num_elements.update(num_elements) diff --git a/src/fairseq2/recipes/wav2vec2/eval.py b/src/fairseq2/recipes/wav2vec2/eval.py new file mode 100644 index 000000000..0481c3fe1 --- /dev/null +++ b/src/fairseq2/recipes/wav2vec2/eval.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import final + +import torch +from typing_extensions import override + +from fairseq2.assets import AssetNotFoundError, default_asset_store +from fairseq2.checkpoint import CheckpointModelMetadataProvider +from fairseq2.config_registry import ConfigRegistry +from fairseq2.datasets import LengthBatching +from fairseq2.datasets.speech import ( + GenericSpeechDataset, + SpeechReadOptions, + load_speech_dataset, +) +from fairseq2.gang import Gang +from fairseq2.logging import get_log_writer +from fairseq2.models import load_model +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.wav2vec2 import Wav2Vec2Model +from fairseq2.nn.utils.module import remove_parametrizations +from fairseq2.recipes.evaluator import AbstractEvalUnit, Evaluator +from fairseq2.recipes.utils.asset import ( + AssetReference, + asset_as_path, + retrieve_asset_card, +) +from fairseq2.recipes.utils.log import log_model +from fairseq2.recipes.utils.setup import broadcast_model, setup_root_gang +from fairseq2.recipes.wav2vec2.common import Wav2Vec2Criterion, Wav2Vec2MetricBag +from fairseq2.typing import META, DataType +from fairseq2.utils.profiler import Stopwatch + +log = get_log_writer(__name__) + + +@dataclass(kw_only=True) +class Wav2Vec2EvalConfig: + """Holds the configuration of a wav2vec 2.0 model evaluation task.""" + + # Data + dataset: AssetReference = "librispeech_960h" + """The name, path or path to the asset card of the dataset to evaluate on.""" + + split: str = "valid" + """The name of the eval data split.""" + + min_audio_len: int = 32_000 + """The minimum audio sequence length.""" + + max_audio_len: int = 250_000 + """The maximum audio sequence length.""" + + max_num_elements: int = 1_500_000 + """The maximum number of elements per batch.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + # Model + model: AssetReference = "wav2vec2_base" + """The name or path to the asset card of the wav2vec 2.0 model to evaluate.""" + + checkpoint_dir: Path | None = None + """The checkpoint directory containing models saved by :class:`FileCheckpointManager`.""" + + dtype: DataType = torch.float16 + """The data type of the model.""" + + amp: bool = False + """If ``True``, runs evaluation with ``torch.amp``.""" + + # Loss + diversity_loss_weight: float = 0.1 + """The weight of the diversity loss.""" + + feature_penalty_weight: float = 10.0 + """The weight of the regularization penalty applied to the extracted features.""" + + # Misc + seed: int = 2 + """The random number generator seed to use.""" + + +wav2vec2_eval_presets = ConfigRegistry[Wav2Vec2EvalConfig]() + +wav2vec2_eval_preset = wav2vec2_eval_presets.decorator + + +@wav2vec2_eval_preset("base_ls960h") +def _base_ls960h() -> Wav2Vec2EvalConfig: + return Wav2Vec2EvalConfig() + + +@torch.inference_mode() +def load_wav2vec2_evaluator( + config: Wav2Vec2EvalConfig, output_dir: Path +) -> Evaluator[SequenceBatch]: + """Load an :class:`Evaluator` for wav2vec 2.0 model evaluation.""" + wall_watch = Stopwatch(start=True) + + if config.checkpoint_dir is not None: + default_asset_store.metadata_providers.append( + CheckpointModelMetadataProvider(config.checkpoint_dir) + ) + + gang = setup_root_gang(log) + + # Load the dataset. + try: + dataset_card = retrieve_asset_card(config.dataset) + except AssetNotFoundError: + dataset_card = None + + if dataset_card is not None: + log.info("Loading {} speech dataset.", dataset_card.name) + + dataset = load_speech_dataset(dataset_card) + + log.info("Dataset loaded.") + else: + dataset_path = asset_as_path(config.dataset) + + dataset = GenericSpeechDataset.from_path(dataset_path) + + model_card = retrieve_asset_card(config.model) + + # Load the model. + log.info("Loading {} model on rank 0.", model_card.name) + + if gang.rank == 0: + init_device = gang.device + else: + init_device = META + + try: + model = load_model(model_card, device=init_device, dtype=config.dtype) + except ValueError as ex: + raise ValueError( + "The model cannot be initialized. See nested exception for details." + ) from ex + + if not isinstance(model, Wav2Vec2Model): + raise ValueError( + f"The model must be of type `{Wav2Vec2Model}`, but is of type `{type(model)}` instead." + ) + + gang.barrier() + + log.info("Model loaded on rank 0.") + + remove_parametrizations(model) + + # Distribute the model to all processes in the gang. + if gang.size != 1: + broadcast_model(model, gang, log) + + log_model(model, log) + + # Initialize the criterion. + criterion = Wav2Vec2Criterion( + model, config.diversity_loss_weight, config.feature_penalty_weight + ) + + # Initialize the unit. + unit = Wav2Vec2EvalUnit(criterion, gang) + + seed = config.seed + + options = SpeechReadOptions( + dtype=config.dtype, + normalize_audio=config.normalize_audio, + sync_mode="until_last", + num_prefetch=config.num_prefetch, + seed=seed, + ) + + try: + data_reader = dataset.create_reader( + config.split, + gang, + config.min_audio_len, + config.max_audio_len, + LengthBatching(config.max_num_elements), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + + seed += 1 + + # Initialize the evaluator. + return Evaluator[SequenceBatch]( + units=[unit], + data_readers=[data_reader], + root_gang=gang, + dtype=config.dtype, + amp=config.amp, + tb_dir=output_dir.joinpath("tb"), + metrics_dir=output_dir.joinpath("metrics"), + seed=seed, + wall_watch=wall_watch, + ) + + +@final +class Wav2Vec2EvalUnit(AbstractEvalUnit[SequenceBatch]): + _criterion: Wav2Vec2Criterion + _metric_bag: Wav2Vec2MetricBag + + def __init__(self, criterion: Wav2Vec2Criterion, gang: Gang) -> None: + super().__init__(criterion.model) + + self._criterion = criterion + + self._metric_bag = Wav2Vec2MetricBag(gang, train=False) + + @override + def __call__(self, batch: SequenceBatch) -> None: + self._criterion(batch, self._metric_bag) + + @property + @override + def metric_bag(self) -> Wav2Vec2MetricBag: + return self._metric_bag diff --git a/src/fairseq2/recipes/wav2vec2/train.py b/src/fairseq2/recipes/wav2vec2/train.py new file mode 100644 index 000000000..48955f38d --- /dev/null +++ b/src/fairseq2/recipes/wav2vec2/train.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, final + +import torch +from torch import Tensor +from typing_extensions import override + +from fairseq2.assets import AssetNotFoundError, default_asset_store +from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager +from fairseq2.config_registry import ConfigRegistry +from fairseq2.datasets import LengthBatching +from fairseq2.datasets.speech import ( + GenericSpeechDataset, + SpeechReadOptions, + load_speech_dataset, +) +from fairseq2.gang import Gang +from fairseq2.logging import get_log_writer +from fairseq2.models import create_model +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.wav2vec2 import Wav2Vec2Model +from fairseq2.optim import AdamWConfig, create_optimizer +from fairseq2.optim.lr_scheduler import PolynomialDecayLRConfig, create_lr_scheduler +from fairseq2.recipes.trainer import AbstractTrainUnit, Trainer +from fairseq2.recipes.utils.asset import ( + AssetReference, + asset_as_path, + retrieve_asset_card, +) +from fairseq2.recipes.utils.log import log_model, log_model_config +from fairseq2.recipes.utils.setup import ( + compile_model, + setup_root_gang, + to_data_parallel, +) +from fairseq2.recipes.wav2vec2.common import Wav2Vec2Criterion, Wav2Vec2MetricBag +from fairseq2.recipes.wav2vec2.eval import Wav2Vec2EvalUnit +from fairseq2.typing import CPU, META, DataType +from fairseq2.utils.profiler import Stopwatch +from fairseq2.utils.rng import manual_seed + +log = get_log_writer(__name__) + + +@dataclass(kw_only=True) +class Wav2Vec2TrainConfig: + """Holds the configuration of a wav2vec 2.0 model training task. + + The default values correspond to the base ls960h training setup as described + in :cite:t:`https://doi.org/10.48550/arxiv.2006.11477`. + """ + + # Data + dataset: AssetReference = "librispeech_960h" + """The name, path or path to the asset card of the speech dataset.""" + + train_split: str = "train" + """The name of the train data split.""" + + valid_split: str = "valid" + """The name of the valid data split.""" + + min_audio_len: int = 32_000 + """The minimum audio sequence length.""" + + max_audio_len: int = 250_000 + """The maximum audio sequence length.""" + + max_num_elements: int = 1_500_000 + """The maximum number of elements per batch.""" + + normalize_audio: bool = False + """If ``True``, normalizes audio to have zero mean and unit variance.""" + + batch_shuffle_window: int = 0 + """The size of the sliding window for shuffling batches.""" + + num_prefetch: int = 4 + """The number of batches to prefetch in background.""" + + # Model + model_family: str = "wav2vec2" + """The family of the model.""" + + model_arch: str | None = "base" + """The architecture of the wav2vec2 model.""" + + model_config: Any = None + """The configuration of the model.""" + + dtype: DataType = torch.float16 + """The data type of the model.""" + + data_parallelism: Literal["ddp", "fsdp"] = "ddp" + """The data parallelism API to use.""" + + fsdp_wrap_granularity: Literal["layer", "stack", "model"] = "stack" + """The granularity at which to wrap the model.""" + + torch_compile: bool = False + """If ``True``, applies ``torch.compile()`` to the encoder. (experimental)""" + + # Optimizer, LR, and Loss + optimizer: str = "adamw" + """The optimizer.""" + + optimizer_config: Any = field( + default_factory=lambda: AdamWConfig( + lr=5e-04, betas=(0.9, 0.98), eps=1e-06, weight_decay=0.01 + ) + ) + """The configuration of the optimizer.""" + + lr_scheduler: str = "polynomial-decay" + """The learning rate scheduler.""" + + lr_scheduler_config: Any = field( + default_factory=lambda: PolynomialDecayLRConfig(num_warmup_steps=32_000) + ) + """The configuration of the learning rate scheduler.""" + + max_gradient_norm: float | None = None + """The maximum gradient norm. If ``None``, no clipping will be applied.""" + + fp16_loss_scale: tuple[float, float] = (128.0, 0.0001) + """The initial and minimum loss scale for fp16 training.""" + + gradient_accumulation: int = 1 + """The number of steps to accumulate gradients before an optimizer update.""" + + diversity_loss_weight: float = 0.1 + """The weight of the diversity loss.""" + + feature_penalty_weight: float = 10.0 + """The weight of the regularization penalty applied to the extracted features.""" + + # Regime + max_num_steps: int = 400_000 + """The maximum number of steps to train for.""" + + max_num_data_epochs: int | None = None + """The maximum number of data epochs to train for.""" + + validate_every_n_steps: int = 5_000 + """The step interval at which to validate the model.""" + + checkpoint_every_n_steps: int = 25_000 + """The step interval at which to checkpoint.""" + + keep_best_n_checkpoints: int | None = 1 + """The number of checkpoints to keep based on their validation score. If + ``None``, none will be deleted.""" + + publish_metrics_every_n_steps: int = 200 + """The step interval at which to publish metrics.""" + + # Checkpoint + resume_checkpoint_dir: Path | None = None + """If not ``None``, adds the specified path to the default asset store.""" + + # Misc + seed: int = 2 + """The random number generator seed to use.""" + + profile: tuple[int, int] | None = None + """The number of steps that the PyTorch profiler should skip and then record.""" + + monitored_gang: bool = False + """If ``True``, puts a monitored barrier before every collective call.""" + + anomaly_detection: bool = False + """If ``True``, enables the anomaly detection feature of ``torch.autograd``.""" + + +wav2vec2_train_presets = ConfigRegistry[Wav2Vec2TrainConfig]() + +wav2vec2_train_preset = wav2vec2_train_presets.decorator + + +@wav2vec2_train_preset("base_960h") +def _base_960h() -> Wav2Vec2TrainConfig: + config = Wav2Vec2TrainConfig() + + config.model_config = {"encoder_config": {"first_pass_dropout_p": 0.1}} + + return config + + +@wav2vec2_train_preset("large_960h") +def _large_960h() -> Wav2Vec2TrainConfig: + config = Wav2Vec2TrainConfig() + + assert isinstance(config.optimizer_config, AdamWConfig) + assert isinstance(config.lr_scheduler_config, PolynomialDecayLRConfig) + + config.max_audio_len = 320_000 + config.max_num_elements = 1_200_000 + config.model_arch = "large" + config.model_config = {"encoder_config": {"first_pass_dropout_p": 0.1}} + config.optimizer_config.lr = 3e-04 + config.lr_scheduler_config.num_warmup_steps = 20_000 + config.max_num_steps = 250_000 + config.publish_metrics_every_n_steps = 100 + + return config + + +def load_wav2vec2_trainer( + config: Wav2Vec2TrainConfig, output_dir: Path +) -> Trainer[SequenceBatch]: + """Load a :class:`Trainer` for wav2vec 2.0 model training.""" + wall_watch = Stopwatch(start=True) + + gang = setup_root_gang(log, monitored=config.monitored_gang) + + checkpoint_manager = FileCheckpointManager(output_dir.joinpath("checkpoints"), gang) + + if config.resume_checkpoint_dir is not None: + default_asset_store.metadata_providers.append( + CheckpointModelMetadataProvider(config.resume_checkpoint_dir) + ) + + # Load the dataset. + try: + dataset_card = retrieve_asset_card(config.dataset) + except AssetNotFoundError: + dataset_card = None + + if dataset_card is not None: + log.info("Loading {} speech dataset.", dataset_card.name) + + dataset = load_speech_dataset(dataset_card) + + log.info("Dataset loaded.") + else: + dataset_path = asset_as_path(config.dataset) + + dataset = GenericSpeechDataset.from_path(dataset_path) + + seed = config.seed + + # Initialize the model + manual_seed(seed, CPU, gang.device) + + seed += 1 + + try: + model, model_config = create_model( + config.model_family, + config.model_arch, + config.model_config, + device=META, + dtype=torch.float32, + ) + except ValueError as ex: + raise ValueError( + "The model cannot be initialized. See nested exception for details." + ) from ex + + if not isinstance(model, Wav2Vec2Model): + raise ValueError( + f"The model must be of type `{Wav2Vec2Model}`, but is of type `{type(model)}` instead." + ) + + log_model_config(model_config, log) + + checkpoint_manager.save_model_metadata(family=model.family, config=model_config) + + has_checkpoint = checkpoint_manager.has_checkpoint() + + dp_model = to_data_parallel( + model, + gang, + config.data_parallelism, + log, + fsdp_broadcast_state=not has_checkpoint, + fsdp_mixed_precision_dtype=config.dtype, + fsdp_fp32_reduce=True, + fsdp_wrap_granularity=config.fsdp_wrap_granularity, + ) + + if config.torch_compile: + model.encoder = compile_model(model.encoder, log) # type: ignore[assignment] + + log_model(dp_model, log, rank=gang.rank) + + # Initialize the train criterion. + criterion = Wav2Vec2Criterion( + dp_model, config.diversity_loss_weight, config.feature_penalty_weight + ) + + # Initialize the train unit. + unit = Wav2Vec2TrainUnit(criterion, gang) + + options = SpeechReadOptions( + dtype=config.dtype, + normalize_audio=config.normalize_audio, + batch_shuffle_window=config.batch_shuffle_window, + num_accumulate=config.gradient_accumulation, + num_prefetch=config.num_prefetch, + seed=seed, + ) + + try: + data_reader = dataset.create_reader( + config.train_split, + gang, + config.min_audio_len, + config.max_audio_len, + LengthBatching(config.max_num_elements), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader cannot be initialized. See nested exception for details." + ) from ex + + seed += 1 + + # Initialize the optimizer. + try: + optimizer = create_optimizer( + config.optimizer, dp_model, config.optimizer_config + ) + except ValueError as ex: + raise ValueError( + "The optimizer cannot be created. See nested exception for details." + ) from ex + + # Initialize the learning rate scheduler. + try: + lr_scheduler = create_lr_scheduler( + config.lr_scheduler, + optimizer, + config.lr_scheduler_config, + max_num_steps=config.max_num_steps, + ) + except ValueError as ex: + raise ValueError( + "The learning rate scheduler cannot be created. See nested exception for details." + ) from ex + + # Initialize the validation unit. + valid_unit = Wav2Vec2EvalUnit(criterion, gang) + + options = SpeechReadOptions( + dtype=config.dtype, + normalize_audio=config.normalize_audio, + sync_mode="until_last", + num_prefetch=config.num_prefetch, + seed=seed, + ) + + try: + valid_data_reader = dataset.create_reader( + config.valid_split, + gang, + config.min_audio_len, + config.max_audio_len, + LengthBatching(config.max_num_elements), + options, + ) + except ValueError as ex: + raise ValueError( + "The data reader for the valid split cannot be initialized. See nested exception for details." + ) from ex + + seed += 1 + + # TODO: Fix once we support static mixed precision on one device. + amp = gang.size == 1 or config.data_parallelism != "fsdp" + + # Initialize the trainer. + return Trainer[SequenceBatch]( + unit=unit, + data_reader=data_reader, + root_gang=gang, + dtype=config.dtype, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + fp16_loss_scale=config.fp16_loss_scale, + max_gradient_norm=config.max_gradient_norm, + amp=amp, + max_num_steps=config.max_num_steps, + max_num_data_epochs=config.max_num_data_epochs, + score_metric_name="loss", + lower_better=True, + valid_units=[valid_unit], + valid_data_readers=[valid_data_reader], + validate_after_n_steps=0, + validate_every_n_steps=config.validate_every_n_steps, + checkpoint_manager=checkpoint_manager, + checkpoint_after_n_steps=0, + checkpoint_every_n_steps=config.checkpoint_every_n_steps, + keep_best_n_checkpoints=config.keep_best_n_checkpoints, + tb_dir=output_dir.joinpath("tb"), + metrics_dir=output_dir.joinpath("metrics"), + publish_metrics_every_n_steps=config.publish_metrics_every_n_steps, + profile=config.profile, + anomaly_detection=config.anomaly_detection, + seed=seed, + wall_watch=wall_watch, + ) + + +@final +class Wav2Vec2TrainUnit(AbstractTrainUnit[SequenceBatch]): + _criterion: Wav2Vec2Criterion + _metric_bag: Wav2Vec2MetricBag + + def __init__(self, criterion: Wav2Vec2Criterion, gang: Gang) -> None: + super().__init__(criterion.model) + + self._criterion = criterion + + self._metric_bag = Wav2Vec2MetricBag(gang) + + @override + def __call__(self, batch: SequenceBatch) -> tuple[Tensor, int]: + return self._criterion(batch, self._metric_bag) + + @property + @override + def metric_bag(self) -> Wav2Vec2MetricBag: + return self._metric_bag diff --git a/src/fairseq2/setup/__init__.py b/src/fairseq2/setup/__init__.py new file mode 100644 index 000000000..132b949df --- /dev/null +++ b/src/fairseq2/setup/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.setup.assets import ( + register_package_metadata_provider as register_package_metadata_provider, +) +from fairseq2.setup.datasets import register_dataset as register_dataset +from fairseq2.setup.root import setup_fairseq2 as setup_fairseq2 +from fairseq2.setup.root import setup_runtime_context as setup_runtime_context +from fairseq2.setup.text_tokenizers import ( + register_text_tokenizer as register_text_tokenizer, +) diff --git a/src/fairseq2/setup/assets.py b/src/fairseq2/setup/assets.py new file mode 100644 index 000000000..86cf9dc1b --- /dev/null +++ b/src/fairseq2/setup/assets.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.assets import ( + FileAssetMetadataProvider, + PackageAssetMetadataProvider, + StandardAssetStore, + WheelPackageFileLister, + get_asset_dir, + get_user_asset_dir, +) +from fairseq2.context import RuntimeContext +from fairseq2.utils.file import StandardFileSystem +from fairseq2.utils.yaml import load_yaml + + +def _register_assets(context: RuntimeContext) -> None: + asset_store = context.asset_store + + # Package Metadata + register_package_metadata_provider(asset_store, "fairseq2.assets.cards") + + # /etc/fairseq2/assets + _register_asset_dir(asset_store) + + # ~/.config/fairseq2/assets + _register_user_asset_dir(asset_store) + + +def _register_asset_dir(asset_store: StandardAssetStore) -> None: + config_dir = get_asset_dir() + if config_dir is None: + return + + file_system = StandardFileSystem() + + provider = FileAssetMetadataProvider(config_dir, file_system, load_yaml) + + asset_store.metadata_providers.append(provider) + + +def _register_user_asset_dir(asset_store: StandardAssetStore) -> None: + config_dir = get_user_asset_dir() + if config_dir is None: + return + + file_system = StandardFileSystem() + + provider = FileAssetMetadataProvider(config_dir, file_system, load_yaml) + + asset_store.user_metadata_providers.append(provider) + + +def register_package_metadata_provider( + asset_store: StandardAssetStore, package_name: str +) -> None: + package_file_lister = WheelPackageFileLister() + + provider = PackageAssetMetadataProvider( + package_name, package_file_lister, load_yaml + ) + + asset_store.metadata_providers.append(provider) diff --git a/src/fairseq2/setup/chatbots.py b/src/fairseq2/setup/chatbots.py new file mode 100644 index 000000000..fb7f64efb --- /dev/null +++ b/src/fairseq2/setup/chatbots.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.chatbots import ChatbotHandler +from fairseq2.chatbots.llama import LLaMAChatbotHandler +from fairseq2.chatbots.mistral import MistralChatbotHandler +from fairseq2.context import RuntimeContext +from fairseq2.models.llama import LLAMA_FAMILY +from fairseq2.models.mistral import MISTRAL_FAMILY + + +def _register_chatbots(context: RuntimeContext) -> None: + registry = context.get_registry(ChatbotHandler) + + registry.register(LLAMA_FAMILY, LLaMAChatbotHandler()) + registry.register(MISTRAL_FAMILY, MistralChatbotHandler()) diff --git a/src/fairseq2/setup/clusters.py b/src/fairseq2/setup/clusters.py new file mode 100644 index 000000000..ba6365c74 --- /dev/null +++ b/src/fairseq2/setup/clusters.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.recipes.cluster import ClusterHandler, SlurmClusterHandler + + +def _register_clusters(context: RuntimeContext) -> None: + registry = context.get_registry(ClusterHandler) + + registry.register("slurm", SlurmClusterHandler()) diff --git a/src/fairseq2/setup/config.py b/src/fairseq2/setup/config.py new file mode 100644 index 000000000..1472d2c9c --- /dev/null +++ b/src/fairseq2/setup/config.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.generation import ( + AlgorithmSection, + AlgorithmSectionHandler, + BeamSearchAlgorithmHandler, + SamplerHandler, + SamplerSection, + SamplerSectionHandler, +) +from fairseq2.utils.config import ConfigSectionHandler + + +def _register_config_sections(context: RuntimeContext) -> None: + registry = context.get_registry(ConfigSectionHandler) + + handler: ConfigSectionHandler + + # Sampler + sampler_handlers = context.get_registry(SamplerHandler) + + handler = SamplerSectionHandler(sampler_handlers) + + registry.register(SamplerSection, handler) + + # Algorithm + algorithm_handlers = context.get_registry(BeamSearchAlgorithmHandler) + + handler = AlgorithmSectionHandler(algorithm_handlers) + + registry.register(AlgorithmSection, handler) diff --git a/src/fairseq2/setup/datasets.py b/src/fairseq2/setup/datasets.py new file mode 100644 index 000000000..ea4bc54ad --- /dev/null +++ b/src/fairseq2/setup/datasets.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.datasets import DatasetHandler, DatasetLoader, StandardDatasetHandler +from fairseq2.datasets.asr import GENERIC_ASR_DATASET_FAMILY, GenericAsrDataset +from fairseq2.datasets.instruction import ( + GENERIC_INSTRUCTION_DATASET_FAMILY, + GenericInstructionDataset, +) +from fairseq2.datasets.parallel_text import ( + GENERIC_PARALLEL_TEXT_DATASET_FAMILY, + GenericParallelTextDataset, +) +from fairseq2.datasets.preference import ( + GENERIC_PREFERENCE_OPTIMIZATION_DATASET_FAMILY, + GenericPreferenceOptimizationDataset, +) +from fairseq2.datasets.speech import GENERIC_SPEECH_DATASET_FAMILY, GenericSpeechDataset +from fairseq2.datasets.text import GENERIC_TEXT_DATASET_FAMILY, GenericTextDataset + + +def _register_datasets(context: RuntimeContext) -> None: + register_dataset( + context, + GENERIC_ASR_DATASET_FAMILY, + kls=GenericAsrDataset, + loader=GenericAsrDataset.from_path, + ) + + register_dataset( + context, + GENERIC_INSTRUCTION_DATASET_FAMILY, + kls=GenericInstructionDataset, + loader=GenericInstructionDataset.from_path, + ) + + register_dataset( + context, + GENERIC_PARALLEL_TEXT_DATASET_FAMILY, + kls=GenericParallelTextDataset, + loader=GenericParallelTextDataset.from_path, + ) + + register_dataset( + context, + GENERIC_PREFERENCE_OPTIMIZATION_DATASET_FAMILY, + kls=GenericPreferenceOptimizationDataset, + loader=GenericPreferenceOptimizationDataset.from_path, + ) + + register_dataset( + context, + GENERIC_SPEECH_DATASET_FAMILY, + kls=GenericSpeechDataset, + loader=GenericSpeechDataset.from_path, + ) + + register_dataset( + context, + GENERIC_TEXT_DATASET_FAMILY, + kls=GenericTextDataset, + loader=GenericTextDataset.from_path, + ) + + +def register_dataset( + context: RuntimeContext, family: str, *, kls: type, loader: DatasetLoader +) -> None: + handler = StandardDatasetHandler(kls, loader, context.asset_download_manager) + + registry = context.get_registry(DatasetHandler) + + registry.register(family, handler) diff --git a/src/fairseq2/setup/generation.py b/src/fairseq2/setup/generation.py new file mode 100644 index 000000000..cf180a560 --- /dev/null +++ b/src/fairseq2/setup/generation.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.generation import ( + BEAM_SEARCH_GENERATOR, + SAMPLING_GENERATOR, + STANDARD_BEAM_SEARCH_ALGO, + TOP_K_SAMPLER, + TOP_P_SAMPLER, + BeamSearchAlgorithmHandler, + BeamSearchSeq2SeqGeneratorHandler, + BeamSearchSequenceGeneratorHandler, + SamplerHandler, + SamplingSeq2SeqGeneratorHandler, + SamplingSequenceGeneratorHandler, + Seq2SeqGeneratorHandler, + SequenceGeneratorHandler, + StandardBeamSearchAlgorithmHandler, + TopKSamplerHandler, + TopPSamplerHandler, +) + + +def _register_seq_generators(context: RuntimeContext) -> None: + registry = context.get_registry(SequenceGeneratorHandler) + + handler: SequenceGeneratorHandler + + # Sampling + sampler_handlers = context.get_registry(SamplerHandler) + + handler = SamplingSequenceGeneratorHandler(sampler_handlers) + + registry.register(SAMPLING_GENERATOR, handler) + + # Beam Search + algorithm_handlers = context.get_registry(BeamSearchAlgorithmHandler) + + handler = BeamSearchSequenceGeneratorHandler(algorithm_handlers) + + registry.register(BEAM_SEARCH_GENERATOR, handler) + + +def _register_seq2seq_generators(context: RuntimeContext) -> None: + registry = context.get_registry(Seq2SeqGeneratorHandler) + + handler: Seq2SeqGeneratorHandler + + # Sampling + sampler_handlers = context.get_registry(SamplerHandler) + + handler = SamplingSeq2SeqGeneratorHandler(sampler_handlers) + + registry.register(SAMPLING_GENERATOR, handler) + + # Beam Search + algorithm_handlers = context.get_registry(BeamSearchAlgorithmHandler) + + handler = BeamSearchSeq2SeqGeneratorHandler(algorithm_handlers) + + registry.register(BEAM_SEARCH_GENERATOR, handler) + + +def _register_samplers(context: RuntimeContext) -> None: + registry = context.get_registry(SamplerHandler) + + registry.register(TOP_P_SAMPLER, TopPSamplerHandler()) + registry.register(TOP_K_SAMPLER, TopKSamplerHandler()) + + +def _register_beam_search_algorithms(context: RuntimeContext) -> None: + registry = context.get_registry(BeamSearchAlgorithmHandler) + + registry.register(STANDARD_BEAM_SEARCH_ALGO, StandardBeamSearchAlgorithmHandler()) diff --git a/src/fairseq2/setup/optim.py b/src/fairseq2/setup/optim.py new file mode 100644 index 000000000..4bd26de4f --- /dev/null +++ b/src/fairseq2/setup/optim.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.optim import ADAMW_OPTIMIZER, AdamWHandler, OptimizerHandler +from fairseq2.optim.lr_scheduler import ( + COSINE_ANNEALING_LR, + MYLE_LR, + NOAM_LR, + POLYNOMIAL_DECAY_LR, + TRI_STAGE_LR, + CosineAnnealingLRHandler, + LRSchedulerHandler, + MyleLRHandler, + NoamLRHandler, + PolynomialDecayLRHandler, + TriStageLRHandler, +) + + +def _register_optimizers(context: RuntimeContext) -> None: + registry = context.get_registry(OptimizerHandler) + + registry.register(ADAMW_OPTIMIZER, AdamWHandler()) + + +def _register_lr_schedulers(context: RuntimeContext) -> None: + registry = context.get_registry(LRSchedulerHandler) + + registry.register(COSINE_ANNEALING_LR, CosineAnnealingLRHandler()) + registry.register(MYLE_LR, MyleLRHandler()) + registry.register(NOAM_LR, NoamLRHandler()) + registry.register(POLYNOMIAL_DECAY_LR, PolynomialDecayLRHandler()) + registry.register(TRI_STAGE_LR, TriStageLRHandler()) diff --git a/src/fairseq2/setup/root.py b/src/fairseq2/setup/root.py new file mode 100644 index 000000000..ee60f321b --- /dev/null +++ b/src/fairseq2/setup/root.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.assets import InProcAssetDownloadManager, default_asset_store +from fairseq2.context import RuntimeContext, set_runtime_context +from fairseq2.extensions import run_extensions +from fairseq2.setup.assets import _register_assets +from fairseq2.setup.chatbots import _register_chatbots +from fairseq2.setup.clusters import _register_clusters +from fairseq2.setup.config import _register_config_sections +from fairseq2.setup.datasets import _register_datasets +from fairseq2.setup.generation import ( + _register_beam_search_algorithms, + _register_samplers, + _register_seq2seq_generators, + _register_seq_generators, +) +from fairseq2.setup.optim import _register_lr_schedulers, _register_optimizers +from fairseq2.setup.text_tokenizers import _register_text_tokenizers + +_setup_called: bool = False + + +def setup_fairseq2() -> None: + """ + Sets up fairseq2. + + As part of the initialization, this function also registers extensions + with via setuptools' `entry-point`__ mechanism. See + :doc:`/basics/runtime_extensions` for more information. + + .. important:: + + This function must be called before using any of the fairseq2 APIs. + + .. __: https://setuptools.pypa.io/en/latest/userguide/entry_point.html + """ + global _setup_called + + if _setup_called: + return + + _setup_called = True # Mark as called to avoid recursive calls. + + context = setup_runtime_context() + + set_runtime_context(context) + + run_extensions("fairseq2") # compat + + +def setup_runtime_context() -> RuntimeContext: + asset_download_manager = InProcAssetDownloadManager() + + context = RuntimeContext(default_asset_store, asset_download_manager) + + _register_assets(context) + _register_beam_search_algorithms(context) + _register_chatbots(context) + _register_clusters(context) + _register_config_sections(context) + _register_datasets(context) + _register_lr_schedulers(context) + _register_optimizers(context) + _register_samplers(context) + _register_seq2seq_generators(context) + _register_seq_generators(context) + _register_text_tokenizers(context) + + run_extensions("fairseq2.extension", context) + + return context diff --git a/src/fairseq2/setup/text_tokenizers.py b/src/fairseq2/setup/text_tokenizers.py new file mode 100644 index 000000000..f4ccadcc5 --- /dev/null +++ b/src/fairseq2/setup/text_tokenizers.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.data.text.tokenizers import ( + StandardTextTokenizerHandler, + TextTokenizerHandler, + TextTokenizerLoader, +) +from fairseq2.data.text.tokenizers.char_tokenizer import ( + CHAR_TOKENIZER_FAMILY, + load_char_tokenizer, +) +from fairseq2.data.text.tokenizers.llama import ( + LLAMA_TOKENIZER_FAMILY, + load_llama_tokenizer, +) +from fairseq2.data.text.tokenizers.mistral import ( + MISTRAL_TOKENIZER_FAMILY, + load_mistral_tokenizer, +) +from fairseq2.data.text.tokenizers.nllb import ( + NLLB_TOKENIZER_FAMILY, + load_nllb_tokenizer, +) +from fairseq2.data.text.tokenizers.s2t_transformer import ( + S2T_TRANSFORMER_TOKENIZER_FAMILY, + load_s2t_transformer_tokenizer, +) + + +def _register_text_tokenizers(context: RuntimeContext) -> None: + register_text_tokenizer( + context, + CHAR_TOKENIZER_FAMILY, + loader=load_char_tokenizer, + ) + + register_text_tokenizer( + context, + LLAMA_TOKENIZER_FAMILY, + loader=load_llama_tokenizer, + ) + + register_text_tokenizer( + context, + MISTRAL_TOKENIZER_FAMILY, + loader=load_mistral_tokenizer, + ) + + register_text_tokenizer( + context, + NLLB_TOKENIZER_FAMILY, + loader=load_nllb_tokenizer, + ) + + register_text_tokenizer( + context, + S2T_TRANSFORMER_TOKENIZER_FAMILY, + loader=load_s2t_transformer_tokenizer, + ) + + +def register_text_tokenizer( + context: RuntimeContext, family: str, *, loader: TextTokenizerLoader +) -> None: + handler = StandardTextTokenizerHandler( + loader=loader, asset_download_manager=context.asset_download_manager + ) + + registry = context.get_registry(TextTokenizerHandler) + + registry.register(family, handler) diff --git a/src/fairseq2/tensor_parallel.py b/src/fairseq2/tensor_parallel.py index cab199a05..507e1c00c 100644 --- a/src/fairseq2/tensor_parallel.py +++ b/src/fairseq2/tensor_parallel.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Tuple +from typing import Any import torch from torch import Tensor @@ -36,7 +36,7 @@ def forward(ctx: Any, x: Tensor, gang: Gang) -> Tensor: return x @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None, None]: + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, None, None]: return grad_output, None, None @@ -58,7 +58,7 @@ def forward(ctx: Any, x: Tensor, gang: Gang, dim: int) -> Tensor: return _do_scatter(x, gang, dim) @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None, None]: + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, None, None]: x = _do_gather(grad_output, ctx.gang, ctx.dim) return x, None, None @@ -82,7 +82,7 @@ def forward(ctx: Any, x: Tensor, gang: Gang, dim: int) -> Tensor: return _do_gather(x, gang, dim) @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None, None]: + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, None, None]: x = _do_scatter(grad_output, ctx.gang, ctx.dim) return x, None, None @@ -101,7 +101,7 @@ def forward(ctx: Any, x: Tensor, gang: Gang) -> Tensor: return x @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None]: + def backward(ctx: Any, grad_output: Tensor) -> tuple[Tensor, None]: ctx.gang.all_reduce(grad_output, ReduceOperation.SUM) return grad_output, None @@ -115,7 +115,7 @@ def _do_scatter(x: Tensor, gang: Gang, dim: int) -> Tensor: if dim_size % gang.size != 0: raise ValueError( - f"The size of the dimension {dim} of `x` must be divisible by `gang.size` ({gang.size}), but is {dim_size} instead." + f"The size of the dimension {dim} of `x` must be a multiple of `gang.size` ({gang.size}), but is {dim_size} instead." ) splits = x.split(dim_size // gang.size, dim=dim) diff --git a/src/fairseq2/typing.py b/src/fairseq2/typing.py index b2cd9bd23..e3a8d6350 100644 --- a/src/fairseq2/typing.py +++ b/src/fairseq2/typing.py @@ -7,29 +7,32 @@ from __future__ import annotations from dataclasses import Field, is_dataclass -from typing import Any, Callable, ClassVar, Dict, Final, Protocol, TypeVar +from typing import Any, ClassVar, Final, Protocol, TypeAlias, TypeGuard, TypeVar from torch import device, dtype -from typing_extensions import TypeAlias, TypeGuard +from typing_extensions import override as override # noqa: F401 +T = TypeVar("T") -class DataClass(Protocol): - """Represents a data class object.""" - __dataclass_fields__: ClassVar[Dict[str, Field[Any]]] +def safe_cast(param_name: str, value: object, kls: type[T]) -> T: + if not isinstance(value, kls): + raise TypeError( + f"`{param_name}` must be of type `{kls}`, but is of type `{type(value)}` instead." + ) + return value -def is_dataclass_instance(obj: Any) -> TypeGuard[DataClass]: - """Return ``True`` if ``obj`` is of type :class:`DataClass`.""" - return is_dataclass(obj) and not isinstance(obj, type) +class DataClass(Protocol): + """Represents a data class object.""" -F = TypeVar("F", bound=Callable[..., Any]) + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] -def override(f: F) -> F: - """Indicate that the decorated member overrides an inherited virtual member.""" - return f +def is_dataclass_instance(obj: object) -> TypeGuard[DataClass]: + """Return ``True`` if ``obj`` is of type :class:`DataClass`.""" + return is_dataclass(obj) and not isinstance(obj, type) Device: TypeAlias = device diff --git a/src/fairseq2/utils/config.py b/src/fairseq2/utils/config.py new file mode 100644 index 000000000..55eee6c6d --- /dev/null +++ b/src/fairseq2/utils/config.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import fields, is_dataclass + +from fairseq2.context import Provider, get_runtime_context +from fairseq2.utils.structured import StructureError + + +class ConfigProcessor: + _section_handlers: Provider[ConfigSectionHandler] + + def __init__(self, section_handlers: Provider[ConfigSectionHandler]) -> None: + self._section_handlers = section_handlers + + def process(self, config: object) -> None: + config_kls = type(config) + + if not is_dataclass(config_kls): + return + + try: + section_handler = self._section_handlers.get(config_kls) + except LookupError: + pass + else: + section_handler.process(config) + + for field in fields(config_kls): + field_value = getattr(config, field.name) + + try: + self.process(field_value) + except StructureError as ex: + raise StructureError( + f"`{field.name}` cannot be structured. See the nested exception for details." + ) from ex + + +class ConfigSectionHandler(ABC): + @abstractmethod + def process(self, section: object) -> None: + ... + + +def process_config(config: object) -> None: + context = get_runtime_context() + + config_section_handlers = context.get_registry(ConfigSectionHandler) + + config_processor = ConfigProcessor(config_section_handlers) + + config_processor.process(config) diff --git a/src/fairseq2/utils/dataclass.py b/src/fairseq2/utils/dataclass.py index a26717a69..98c568d29 100644 --- a/src/fairseq2/utils/dataclass.py +++ b/src/fairseq2/utils/dataclass.py @@ -6,149 +6,95 @@ from __future__ import annotations -from dataclasses import asdict, fields -from typing import Any, Dict, List, Mapping, Optional, TextIO, cast, get_type_hints +from collections.abc import MutableMapping +from dataclasses import MISSING, fields +from typing import Any, TypeVar, cast -import yaml +from typing_extensions import Self from fairseq2.typing import DataClass, is_dataclass_instance -from fairseq2.utils.value_converter import ValueConverter, default_value_converter -def update_dataclass( - obj: DataClass, - overrides: Mapping[str, Any], - *, - value_converter: Optional[ValueConverter] = None, -) -> List[str]: - """Update ``obj`` with the data contained in ``overrides``. +class _EmptyType: + def __reduce__(self) -> str: + return "EMPTY" - :param obj: - The data class instance to update. - :param overrides: - The dictionary containing the data to set in ``obj``. - :param value_converter: - The :class:`ValueConverter` instance to use. If ``None``, the default - instance will be used. - """ - if value_converter is None: - value_converter = default_value_converter + def __copy__(self) -> Self: + return self - unknown_fields: List[str] = [] + def __deepcopy__(self, memo: MutableMapping[Any, Any]) -> Self: + return self - field_path: List[str] = [] + def __repr__(self) -> str: + return "" - def update(obj_: DataClass, overrides_: Mapping[str, Any]) -> None: - overrides_copy = {**overrides_} - type_hints = get_type_hints(type(obj_)) +EMPTY = _EmptyType() +"""A sentinel signifying no value for a dataclass field.""" - for field in fields(obj_): - value = getattr(obj_, field.name) - try: - override = overrides_copy.pop(field.name) - except KeyError: - continue +T = TypeVar("T", bound=DataClass) - # Recursively traverse child dataclasses. - if override is not None and is_dataclass_instance(value): - if not isinstance(override, Mapping): - pathname = ".".join(field_path + [field.name]) - raise FieldError( - pathname, f"The field '{pathname}' is expected to be of type `{type(value)}`, but is of type `{type(override)}` instead." # fmt: skip - ) - - field_path.append(field.name) - - update(value, override) - - field_path.pop() - else: - type_hint = type_hints[field.name] - - try: - override = value_converter.structure(override, type_hint) - except (TypeError, ValueError) as ex: - pathname = ".".join(field_path + [field.name]) - - raise FieldError( - pathname, f"The value of the field '{pathname}' cannot be parsed. See nested exception for details" # fmt: skip - ) from ex - - setattr(obj_, field.name, override) - - if overrides_copy: - unknown_fields.extend( - ".".join(field_path + [name]) for name in overrides_copy - ) - - update(obj, overrides) - - unknown_fields.sort() - - return unknown_fields - - -class FieldError(RuntimeError): - """Raised when a dataclass field cannot be parsed.""" +def merge_dataclass(target: T, source: T) -> T: + """Merge ``target`` with the data contained in ``source``.""" + if type(target) is not type(source): + raise TypeError( + f"`target` and `source` must be of the same type, but they are of types `{type(target)}` and `{type(source)}` instead." + ) - _field_name: str + return cast(T, _copy_dataclass(target, source)) - def __init__(self, field_name: str, message: str) -> None: - super().__init__(message) - self._field_name = field_name +def _copy_dataclass(target: DataClass, source: DataClass) -> DataClass: + kls = type(target) - @property - def field_name(self) -> str: - return self._field_name + kwargs = {} + for field in fields(kls): + if not field.init: + continue -def dump_dataclass(obj: DataClass, output_stream: TextIO) -> None: - """Dump ``obj`` to ``fp`` in YAML format.""" - yaml.safe_dump(to_safe_dict(obj), output_stream, sort_keys=False) + source_value = getattr(source, field.name) + if source_value is EMPTY: + value = getattr(target, field.name) + else: + if is_dataclass_instance(source_value): + target_value = getattr(target, field.name) + if type(target_value) is type(source_value): + value = _copy_dataclass(target_value, source_value) + else: + value = _copy_dataclass_with_defaults(source_value) + else: + value = source_value -def to_safe_dict( - obj: DataClass, value_converter: Optional[ValueConverter] = None -) -> Dict[str, Any]: - """Convert ``obj`` to a :class:`dict` safe to serialize in YAML.""" - if value_converter is None: - value_converter = default_value_converter + kwargs[field.name] = value - try: - data = value_converter.unstructure(asdict(obj)) - except TypeError as ex: - raise ValueError( - "`obj` must contain only values that can be serialized to standard YAML. See nested exception for details." - ) from ex + return kls(**kwargs) - def sanity_check(data_: Any) -> None: - if data_ is None: - return - if isinstance(data_, (bool, int, float, str)): - return +def _copy_dataclass_with_defaults(obj: DataClass) -> DataClass: + kls = type(obj) - if isinstance(data_, list): - for e in data_: - sanity_check(e) + kwargs = {} - return + for field in fields(kls): + if not field.init: + continue - if isinstance(data_, dict): - for k, v in data_.items(): - sanity_check(k) - sanity_check(v) + value = getattr(obj, field.name) + if value is EMPTY: + if field.default == MISSING or field.default_factory == MISSING: + raise ValueError( + f"The `{field.name}` field of `{kls}` in `target` must have a default value or factory." + ) - return + continue - raise RuntimeError( - f"Unstructured output of `obj` must contain only primitive types, lists, and dicts, but it contains a value of type `{type(data_)}`." - ) + if is_dataclass_instance(value): + value = _copy_dataclass_with_defaults(value) - sanity_check(data) + kwargs[field.name] = value - return cast(Dict[str, Any], data) + return kls(**kwargs) diff --git a/src/fairseq2/utils/env.py b/src/fairseq2/utils/env.py index 9c9771a60..6434406e4 100644 --- a/src/fairseq2/utils/env.py +++ b/src/fairseq2/utils/env.py @@ -8,19 +8,19 @@ import os from pathlib import Path -from typing import Optional -from fairseq2.logging import LogWriter +from fairseq2.logging import log +from fairseq2.typing import Device -def get_int_from_env(var_name: str, allow_zero: bool = False) -> Optional[int]: +def get_int_from_env(var_name: str, allow_zero: bool = False) -> int | None: """Return the value of an environment variable as ``int``. :param var_name: The name of the environment variable. :param allow_zero: If ``True``, returns the value if it equals to zero; otherwise, raises - a :class:`RuntimeError`. + a :class:`InvalidEnvironmentVariableError`. """ s = os.environ.get(var_name) if s is None: @@ -29,27 +29,25 @@ def get_int_from_env(var_name: str, allow_zero: bool = False) -> Optional[int]: try: value = int(s) except ValueError: - raise RuntimeError( - f"The value of the `{var_name}` environment variable must be an integer, but is '{s}' instead." + raise InvalidEnvironmentVariableError( + f"The value of the `{var_name}` environment variable is expected to be an integer, but is '{s}' instead." ) from None if not allow_zero: if not value >= 1: - raise RuntimeError( - f"The value of the `{var_name}` environment variable must be greater than 0, but is {value} instead." + raise InvalidEnvironmentVariableError( + f"The value of the `{var_name}` environment variable is expected to be a positive integer, but is {value} instead." ) else: if not value >= 0: - raise RuntimeError( - f"The value of the `{var_name}` environment variable must be greater than or equal to 0, but is {value} instead." + raise InvalidEnvironmentVariableError( + f"The value of the `{var_name}` environment variable is expected to be greater than or equal to 0, but is {value} instead." ) return value -def get_path_from_env( - var_name: str, log: LogWriter, missing_ok: bool = False -) -> Optional[Path]: +def get_path_from_env(var_name: str, missing_ok: bool = False) -> Path | None: """Return the value of an environment variable as :class:`~pathlib.Path`. :param var_name: @@ -58,7 +56,7 @@ def get_path_from_env( The log to write to. :param missing_ok: If ``True``, returns ``None`` if the path does not exist; otherwise, - raises a :class:`RuntimeError`. + raises a :class:`InvalidEnvironmentVariableError`. """ pathname = os.environ.get(var_name) if not pathname: @@ -66,10 +64,10 @@ def get_path_from_env( try: path = Path(pathname) - except ValueError as ex: - raise RuntimeError( - f"The value of the `{var_name}` environment variable must be a pathname, but is '{pathname}' instead." - ) from ex + except ValueError: + raise InvalidEnvironmentVariableError( + f"The value of the `{var_name}` environment variable is expected to be a pathname, but is '{pathname}' instead." + ) from None resolved_path = path.expanduser().resolve() @@ -77,8 +75,25 @@ def get_path_from_env( if missing_ok: return resolved_path - log.warning("The path '{}' pointed to by the `{}` environment variable does not exist.", path, var_name) # fmt: skip + log.warning("The '{}' path pointed to by the `{}` environment variable does not exist.", path, var_name) # fmt: skip return None return resolved_path + + +def get_device_from_env(var_name: str) -> Device | None: + device_str = os.environ.get(var_name) + if device_str is None: + return None + + try: + return Device(device_str) + except (RuntimeError, ValueError): + raise InvalidEnvironmentVariableError( + f"The value of the `{var_name}` environment variable is expected to specify a PyTorch device, but is '{device_str}' instead." + ) from None + + +class InvalidEnvironmentVariableError(Exception): + pass diff --git a/src/fairseq2/utils/file.py b/src/fairseq2/utils/file.py index a9396c08f..71d3dc40c 100644 --- a/src/fairseq2/utils/file.py +++ b/src/fairseq2/utils/file.py @@ -6,32 +6,76 @@ from __future__ import annotations +import os import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Mapping, Sequence from pathlib import Path -from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union +from pickle import PickleError +from typing import Protocol, TypeAlias, final from warnings import catch_warnings import torch from torch import Tensor -from typing_extensions import TypeAlias +from typing_extensions import override +from fairseq2.error import NotSupportedError from fairseq2.typing import Device -MapLocation: TypeAlias = Optional[ - Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]] -] + +class FileSystem(ABC): + @abstractmethod + def is_file(self, path: Path) -> bool: + ... + + @abstractmethod + def make_directory(self, path: Path) -> None: + ... + + @abstractmethod + def walk_directory( + self, path: Path, *, on_error: Callable[[OSError], None] | None + ) -> Iterable[tuple[str, Sequence[str]]]: + ... + + @abstractmethod + def resolve(self, path: Path) -> Path: + ... + + +@final +class StandardFileSystem(FileSystem): + @override + def is_file(self, path: Path) -> bool: + return path.is_file() + + @override + def make_directory(self, path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + + @override + def walk_directory( + self, path: Path, *, on_error: Callable[[OSError], None] | None + ) -> Iterable[tuple[str, Sequence[str]]]: + for dir_pathname, _, filenames in os.walk(path, onerror=on_error): + yield dir_pathname, filenames + + @override + def resolve(self, path: Path) -> Path: + return path.expanduser().resolve() + + +MapLocation: TypeAlias = ( + Callable[[Tensor, str], Tensor] | Device | str | dict[str, str] | None +) class TensorLoader(Protocol): """Loads tensors from files.""" def __call__( - self, - path: Path, - *, - map_location: MapLocation = None, - restrict: bool = False, - ) -> Dict[str, Any]: + self, path: Path, *, map_location: MapLocation = None, restrict: bool = False + ) -> dict[str, object]: """ :param path: The path to the file. @@ -46,7 +90,7 @@ def __call__( class TensorDumper(Protocol): """Dumps tensors to files.""" - def __call__(self, data: Mapping[str, Any], path: Path) -> None: + def __call__(self, data: Mapping[str, object], path: Path) -> None: """ :param data: The dictionary containing tensors and other auxiliary data. @@ -55,23 +99,119 @@ def __call__(self, data: Mapping[str, Any], path: Path) -> None: """ -def load_tensors( - path: Path, - *, - map_location: MapLocation = None, - restrict: bool = False, -) -> Dict[str, Any]: +def load_torch_tensors( + path: Path, *, map_location: MapLocation = None, restrict: bool = False +) -> dict[str, object]: """Load the PyTorch tensor file stored under ``path``.""" with catch_warnings(): - warnings.simplefilter("ignore") # Suppress the deprecation warning. - - data: Dict[str, Any] = torch.load( - str(path), map_location, weights_only=restrict - ) + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + try: + data: dict[str, object] = torch.load( + str(path), map_location, weights_only=restrict # type: ignore[arg-type] + ) + except FileNotFoundError: + raise + except (RuntimeError, OSError, PickleError) as ex: + raise TensorLoadError( + f"The '{path}' tensor file cannot be loaded. See the nested exception for details." + ) from ex return data -def dump_tensors(data: Mapping[str, Any], path: Path) -> None: +def dump_torch_tensors(data: Mapping[str, object], path: Path) -> None: """Dump ``data`` to a PyTorch tensor file under ``path``.""" - torch.save(data, path) + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + try: + torch.save(data, path) + except (RuntimeError, OSError, PickleError) as ex: + raise TensorDumpError( + f"The '{path}' tensor file cannot be dumped. See the nested exception for details.", + ) from ex + + +def load_safetensors( + path: Path, *, map_location: MapLocation = None, restrict: bool = False +) -> dict[str, object]: + """Load the Hugging Face Safetensors file(s) stored under ``path``.""" + try: + from safetensors import safe_open # type: ignore[import-not-found] + except ImportError: + raise NotSupportedError( + "Safetensors not found in your Python environment. Use `pip install safetensors`." + ) + + if map_location is not None: + if not isinstance(map_location, (Device, str)): + raise NotSupportedError( + "Safetensors only supports `torch.device` and `str` for the `map_location` parameter." + ) + + if path.is_dir(): + files = list(path.glob("*.safetensors")) + if not files: + raise TensorLoadError( + f"No Safetensors file found under the '{path}' directory." + ) + else: + files = [path] + + tensors = {} + + for file in files: + try: + with safe_open(file, framework="pt", device=str(map_location)) as f: # type: ignore[attr-defined] + for k in f.keys(): + if k in tensors: + raise TensorLoadError( + f"The '{k}' key exists in more than one Safetensors file under the '{path}' directory." + ) + + tensors[k] = f.get_tensor(k) + except FileNotFoundError: + raise + except (RuntimeError, OSError, PickleError) as ex: + raise TensorLoadError( + f"The '{file}' tensor file cannot be loaded. See the nested exception for details." + ) from ex + + return tensors + + +def load_tensors( + path: Path, *, map_location: MapLocation = None, restrict: bool = False +) -> dict[str, object]: + """Load the tensors stored under ``path``.""" + + def has_files(path: Path, extension: str) -> bool: + try: + next(iter(path.glob("*" + extension))) + except StopIteration: + return False + + return True + + if path.is_dir(): + if not has_files(path, ".safetensors"): + raise TensorLoadError( + f"The '{path}' directory does not contain any supported tensor files." + ) + + loader = load_safetensors + elif path.suffix == ".safetensors": + loader = load_safetensors + else: + loader = load_torch_tensors + + return loader(path, map_location=map_location, restrict=restrict) + + +class TensorLoadError(Exception): + pass + + +class TensorDumpError(Exception): + pass diff --git a/src/fairseq2/utils/lazy.py b/src/fairseq2/utils/lazy.py new file mode 100644 index 000000000..f0a662d23 --- /dev/null +++ b/src/fairseq2/utils/lazy.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Callable +from typing import Generic, TypeVar, cast, final + +T_co = TypeVar("T_co", covariant=True) + + +@final +class Lazy(Generic[T_co]): + _value: T_co | Callable[[], T_co] + _constructed: bool + + def __init__(self, factory: Callable[[], T_co]): + self._value = factory + self._constructed = False + + def retrieve(self) -> T_co: + if not self._constructed: + self._value = cast(Callable[[], T_co], self._value)() + + self._constructed = True + + return cast(T_co, self._value) diff --git a/src/fairseq2/utils/profiler.py b/src/fairseq2/utils/profiler.py index 4f2981d22..d9a356d6f 100644 --- a/src/fairseq2/utils/profiler.py +++ b/src/fairseq2/utils/profiler.py @@ -8,7 +8,7 @@ from pathlib import Path from time import perf_counter -from typing import Any, Optional, final +from typing import Any, final import torch from torch.profiler import ( @@ -19,6 +19,7 @@ ) from typing_extensions import Self +from fairseq2.error import InvalidOperationError from fairseq2.gang import Gang from fairseq2.typing import Device @@ -27,7 +28,7 @@ class Profiler: """Represents a convenience wrapper for :class:`profile`.""" - _profile: Optional[profile] + _profile: profile | None def __init__( self, @@ -100,7 +101,7 @@ def __exit__(self, *exc: Any) -> None: self.stop() @property - def wrapped_profile(self) -> Optional[profile]: + def wrapped_profile(self) -> profile | None: """The wrapped :class:`profile` instance.""" return self._profile @@ -109,10 +110,10 @@ def wrapped_profile(self) -> Optional[profile]: class Stopwatch: """Measures elapsed execution time.""" - _start_time: Optional[float] - _device: Optional[Device] + _start_time: float | None + _device: Device | None - def __init__(self, *, start: bool = False, device: Optional[Device] = None) -> None: + def __init__(self, *, start: bool = False, device: Device | None = None) -> None: """ :param start: If ``True``, starts the stopwatch immediately. @@ -122,6 +123,13 @@ def __init__(self, *, start: bool = False, device: Optional[Device] = None) -> N negative impact on the runtime performance if not used carefully. """ self._start_time = None + + if device is not None: + if device.type != "cpu" and device.type != "cuda": + raise ValueError( + f"The type of `device` must be `cpu` or `cuda`, but is `{device.type}` instead." + ) + self._device = device if start: @@ -130,7 +138,7 @@ def __init__(self, *, start: bool = False, device: Optional[Device] = None) -> N def start(self) -> None: """Start the stopwatch.""" if self._start_time is not None: - raise RuntimeError("The stopwatch is already running.") + raise InvalidOperationError("The stopwatch is already running.") self._sync_device() @@ -143,7 +151,7 @@ def stop(self) -> None: def reset(self) -> None: """Reset the stopwatch.""" if self._start_time is None: - raise RuntimeError("The stopwatch is not running.") + raise InvalidOperationError("The stopwatch is not running.") self._sync_device() diff --git a/src/fairseq2/utils/rng.py b/src/fairseq2/utils/rng.py index 992397f64..7a5abfb1b 100644 --- a/src/fairseq2/utils/rng.py +++ b/src/fairseq2/utils/rng.py @@ -6,18 +6,9 @@ from __future__ import annotations -from contextlib import contextmanager, nullcontext -from typing import ( - Any, - ContextManager, - Dict, - Iterable, - Iterator, - List, - Mapping, - Optional, - final, -) +from collections.abc import Iterator, Mapping +from contextlib import AbstractContextManager, contextmanager +from typing import final import torch from torch import Generator, Tensor @@ -43,7 +34,7 @@ def use_deterministic(value: bool, warn_only: bool = False) -> None: class RngBag: """Holds a collection of random number generators.""" - _generators: List[Generator] + _generators: list[Generator] def __init__(self, *generators: Generator) -> None: """ @@ -54,12 +45,12 @@ def __init__(self, *generators: Generator) -> None: @staticmethod def from_device_defaults(*devices: Device) -> RngBag: - """Create an :class:`RngBag` from the random number generators of ``devices``.""" + """Make an :class:`RngBag` from the random number generators of ``devices``.""" unique_devices = set() generators = [] - for device in devices: + for idx, device in enumerate(devices): if device in unique_devices or device.type == "meta": continue @@ -78,7 +69,7 @@ def from_device_defaults(*devices: Device) -> RngBag: generators.append(torch.cuda.default_generators[idx]) else: raise ValueError( - f"`devices` must be of type `cpu` or `cuda`, but at least one device is of type `{device.type}` instead." + f"`devices` must be of type `cpu` or `cuda`, but the device at index {idx} is of type `{device.type}` instead." ) return RngBag(*generators) @@ -122,49 +113,46 @@ def temporary_manual_seed(self, seed: int) -> Iterator[None]: for g, s in zip(self._generators, original_states): g.set_state(s) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, object]: return {"generators": [g.get_state() for g in self._generators]} - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: try: states = state_dict["generators"] except KeyError: - raise ValueError( - "`state_dict` must contain an item named `generators`." - ) from None + raise ValueError("`state_dict` must contain a 'generators' key.") from None + + if len(state_dict) != 1: + raise ValueError("`state_dict` must contain only a 'generators' key.") if not isinstance(states, list): raise TypeError( - f"The `generators` item of `state_dict` must be of type `{list}`, but is of type `{type(states)}` instead." + f"`state_dict['generators']` must be of type `{list}`, but is of type `{type(states)}` instead." ) if len(states) != len(self._generators): raise ValueError( - f"The number of generators in `state_dict` must match the number of generators in the bag ({len(self._generators)}), but is {len(states)} instead." + f"The number of generators in `state_dict['generators']` must match the number of generators in the bag ({len(self._generators)}), but is {len(states)} instead." ) for idx, state in enumerate(states): if not isinstance(state, Tensor): raise TypeError( - f"The generator states in `state_dict` must be of type `{Tensor}`, but the element at index {idx} is of type `{type(state)}` instead." + f"The generator states in `state_dict['generators']` must be of type `Tensor`, but the item at index {idx} is of type `{type(state)}` instead." ) self._generators[idx].set_state(state.clone()) -def temporary_manual_seed( - devices: Iterable[Device], seed: Optional[int] -) -> ContextManager[None]: - """Temporarily change the seed of the random number generators of ``devices``. +def manual_seed(seed: int, *devices: Device) -> None: + """Change the seed of the random number generators of ``devices``.""" + rng_bag = RngBag.from_device_defaults(*devices) + + rng_bag.manual_seed(seed) - :param devices: - The devices whose random number generators will be updated. - :param seed: - The seed to set. If ``None``, becomes a no-op. - """ - if seed is None: - return nullcontext() +def temporary_manual_seed(seed: int, *devices: Device) -> AbstractContextManager[None]: + """Temporarily change the seed of the random number generators of ``devices``.""" rng_bag = RngBag.from_device_defaults(*devices) return rng_bag.temporary_manual_seed(seed) diff --git a/src/fairseq2/utils/state.py b/src/fairseq2/utils/state.py index 7302e5b03..1bc74e12a 100644 --- a/src/fairseq2/utils/state.py +++ b/src/fairseq2/utils/state.py @@ -7,36 +7,28 @@ from __future__ import annotations import logging +import warnings from abc import ABC, abstractmethod -from typing import ( - Any, - Dict, - Generic, - Mapping, - Optional, - Protocol, - Set, - Tuple, - TypeVar, - final, - runtime_checkable, -) +from collections.abc import Mapping +from typing import Any, Generic, Protocol, TypeVar, final, runtime_checkable +from warnings import catch_warnings from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn import Module from torch.optim import Optimizer +from typing_extensions import override -from fairseq2.typing import override +from fairseq2.nn.utils.module import load_state_dict @runtime_checkable class Stateful(Protocol): """Represents an object that follows the ``state_dict`` convention.""" - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, object]: ... - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: ... @@ -46,8 +38,8 @@ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: class StatefulObjectBag: """Holds a collection of stateful objects.""" - _non_stateful_attrs: Set[str] - _explicit_stateful_attrs: Dict[str, Optional[StateHandler[Any]]] + _non_stateful_attrs: set[str] + _explicit_stateful_attrs: dict[str, StateHandler[Any] | None] def __init__(self) -> None: super().__init__() # play nicely as a mixin. @@ -74,7 +66,7 @@ def register_stateful( self, name: str, obj: StatefulT, - state_handler: Optional[StateHandler[StatefulT]] = None, + state_handler: StateHandler[StatefulT] | None = None, ) -> None: """Add ``obj`` to the bag and preserve its state in ``state_dict``. @@ -97,7 +89,7 @@ def register_stateful( setattr(self, name, obj) @final - def register_non_stateful(self, name: str, obj: Any) -> None: + def register_non_stateful(self, name: str, obj: object) -> None: """Add ``obj`` to the bag, but do not preserve its state in ``state_dict``. :param name: @@ -115,10 +107,10 @@ def register_non_stateful(self, name: str, obj: Any) -> None: setattr(self, name, obj) @final - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, object]: state_dict = {} - state: Any + state: object for name, obj in self.__dict__.items(): if name in self._non_stateful_attrs: @@ -129,13 +121,13 @@ def state_dict(self) -> Dict[str, Any]: if is_explicit: if state_handler is None: if isinstance(obj, Stateful): - state = obj.state_dict() + state = self._state_dict(obj) else: state = obj else: state = state_handler.get_state(obj) elif isinstance(obj, Stateful) and not self._is_dunder(name): - state = obj.state_dict() + state = self._state_dict(obj) else: continue @@ -143,8 +135,28 @@ def state_dict(self) -> Dict[str, Any]: return state_dict + @staticmethod + def _state_dict(obj: Stateful) -> dict[str, object]: + if isinstance(obj, FSDP): + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + return obj.state_dict() + + return obj.state_dict() + @final - def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + def load_state_dict(self, state_dict: Mapping[str, object]) -> None: + def state_error(name: str, obj: object) -> ValueError: + return ValueError( + f"`state_dict['{name}']` is not a valid `{type(obj)}` state. See the nested exception for details." + ) + + def state_type_error(name: str, state: object) -> TypeError: + return TypeError( + f"`state_dict['{name}']` must be of type `{Mapping}`, but is of type `{type(state)}` instead." + ) + missing_stateful_attrs = [] state_dict_ = dict(state_dict) @@ -165,11 +177,20 @@ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: if state_handler is None: if isinstance(obj, Stateful): - obj.load_state_dict(state) + if not isinstance(state, Mapping): + raise state_type_error(name, state) + + try: + self._load_state_dict(obj, state) + except (ValueError, TypeError) as ex: + raise state_error(name, obj) from ex else: setattr(self, name, state) else: - state_handler.set_state(obj, state) + try: + state_handler.set_state(obj, state) + except (ValueError, TypeError) as ex: + raise state_error(name, obj) from ex elif isinstance(obj, Stateful) and not self._is_dunder(name): try: state = state_dict_.pop(name) @@ -178,25 +199,48 @@ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: continue - obj.load_state_dict(state) + if not isinstance(state, Mapping): + raise state_type_error(name, state) + + try: + self._load_state_dict(obj, state) + except (ValueError, TypeError) as ex: + raise state_error(name, obj) from ex if missing_stateful_attrs: missing_stateful_attrs.sort() + s = ", ".join(missing_stateful_attrs) + raise ValueError( - f"`state_dict` must contain the states of the following attributes: {', '.join(missing_stateful_attrs)}" + f"`state_dict` must contain the states of all of the following attribute(s): {s}" ) if state_dict_: - extra_keys = list(state_dict_.keys()) - - extra_keys.sort() + s = ", ".join(sorted(state_dict_.keys())) raise ValueError( - f"`state_dict` must only contain the states of the attributes of this object, but it contains the following extra keys: {', '.join(extra_keys)}" + f"`state_dict` must contain only the states of the attributes of this object, but it contains the following unexpected keys: {s}" ) - def _is_explicit(self, name: str) -> Tuple[bool, Optional[StateHandler[Any]]]: + @staticmethod + def _load_state_dict(obj: Stateful, state: Mapping[str, object]) -> None: + if isinstance(obj, FSDP): + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + load_state_dict(obj, state) + + return + + if isinstance(obj, Module): + load_state_dict(obj, state) + + return + + return obj.load_state_dict(state) + + def _is_explicit(self, name: str) -> tuple[bool, StateHandler[Any] | None]: try: state_handler = self._explicit_stateful_attrs[name] @@ -216,11 +260,11 @@ class StateHandler(ABC, Generic[StatefulT]): :class:`StatefulObjectBag`.""" @abstractmethod - def get_state(self, stateful: StatefulT) -> Any: + def get_state(self, stateful: StatefulT) -> object: """Get the state of ``stateful``.""" @abstractmethod - def set_state(self, stateful: StatefulT, state: Any) -> None: + def set_state(self, stateful: StatefulT, state: object) -> None: """Set the state of ``stateful`` to ``state``.""" @@ -239,25 +283,42 @@ def __init__(self, module: Module) -> None: self._module = module @override - def get_state(self, stateful: Optimizer) -> Any: - try: - # PyTorch 2.2 wrongfully uses warning level to dump a lot of noisy - # internal trace information. - logging.disable(logging.WARNING) - - return FSDP.optim_state_dict(self._module, stateful) - finally: - logging.disable(logging.NOTSET) + def get_state(self, stateful: Optimizer) -> object: + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + try: + # FSDP uses warning level to dump a lot of noisy internal trace + # information. + logging.disable(logging.WARNING) + + return FSDP.optim_state_dict(self._module, stateful) + except UnicodeDecodeError as ex: + raise RuntimeError( + "FSDP has failed to gather the optimizer state with a pickling error. This might indicate a disk space issue. Make sure you have enough space on your file system. See the nested exception for details." + ) from ex + finally: + logging.disable(logging.NOTSET) @override - def set_state(self, stateful: Optimizer, state: Any) -> None: - try: - # PyTorch 2.2 wrongfully uses warning level to dump a lot of noisy - # internal trace information. - logging.disable(logging.WARNING) + def set_state(self, stateful: Optimizer, state: object) -> None: + if not isinstance(state, dict): + raise TypeError( + f"`state` must be of type `dict`, but is of type `{type(state)}` instead." + ) + + with catch_warnings(): + warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. + + try: + # FSDP uses warning level to dump a lot of noisy internal trace + # information. + logging.disable(logging.WARNING) - state_dict = FSDP.optim_state_dict_to_load(self._module, stateful, state) - finally: - logging.disable(logging.NOTSET) + state_dict = FSDP.optim_state_dict_to_load( + self._module, stateful, state + ) + finally: + logging.disable(logging.NOTSET) - stateful.load_state_dict(state_dict) + stateful.load_state_dict(state_dict) diff --git a/src/fairseq2/utils/structured.py b/src/fairseq2/utils/structured.py new file mode 100644 index 000000000..9b1451540 --- /dev/null +++ b/src/fairseq2/utils/structured.py @@ -0,0 +1,752 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Mapping, Sequence, Set +from copy import deepcopy +from dataclasses import MISSING, fields, is_dataclass +from enum import Enum +from pathlib import Path +from types import NoneType, UnionType +from typing import ( + Any, + Literal, + Protocol, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, +) + +import torch + +from fairseq2.typing import DataClass, DataType, Device +from fairseq2.utils.dataclass import EMPTY + + +class _Structurer(Protocol): + def __call__( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> object: + ... + + +class _Unstructurer(Protocol): + def __call__(self, obj: object) -> object: + ... + + +class ValueConverter: + """Structures objects using provided type expressions.""" + + _structurers: dict[object, _Structurer] + _unstructurers: dict[type, _Unstructurer] + + def __init__(self) -> None: + self._structurers = { + # fmt: off + bool: self._structure_identity, + DataClass: self._structure_dataclass, + DataType: self._structure_dtype, + Device: self._structure_device, + dict: self._structure_dict, + float: self._structure_primitive, + Enum: self._structure_enum, + int: self._structure_primitive, + list: self._structure_list, + Literal: self._structure_literal, + Mapping: self._structure_dict, + NoneType: self._structure_identity, + Path: self._structure_path, + Sequence: self._structure_list, + set: self._structure_set, + Set: self._structure_set, + str: self._structure_identity, + tuple: self._structure_tuple, + Union: self._structure_union, + UnionType: self._structure_union, + # fmt: on + } + + self._unstructurers = { + # fmt: off + bool: self._unstructure_identity, + DataClass: self._unstructure_dataclass, + DataType: self._unstructure_dtype, + Device: self._unstructure_device, + float: self._unstructure_identity, + Enum: self._unstructure_enum, + int: self._unstructure_identity, + list: self._unstructure_sequence, + Mapping: self._unstructure_mapping, + NoneType: self._unstructure_identity, + Path: self._unstructure_path, + Set: self._unstructure_set, + str: self._unstructure_identity, + tuple: self._unstructure_sequence, + # fmt: on + } + + def structure(self, obj: object, type_: object, *, set_empty: bool = False) -> Any: + orig_type, type_args = get_origin(type_), get_args(type_) + + if orig_type is None: + orig_type = type_ + + if orig_type is object or orig_type is Any: + return obj + + lookup_type = orig_type + + if isinstance(orig_type, type): + if is_dataclass(orig_type): + lookup_type = DataClass + elif issubclass(orig_type, Enum): + lookup_type = Enum + + structurer = self._structurers.get(lookup_type) + if structurer is None: + supported_types = ", ".join(str(t) for t in self._structurers.keys()) + + raise StructureError( + f"`type_` must be the value of a type expression consisting of the following types, but is `{type_}` instead: {supported_types}" + ) from None + + try: + return structurer(orig_type, type_args, obj, set_empty) + except StructureError as ex: + raise StructureError( + f"`obj` cannot be structured to `{type_}`. See the nested exception for details." + ) from ex + + @staticmethod + def _structure_primitive( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> object: + kls = cast(type, orig_type) + + if isinstance(obj, kls): + return obj + + try: + return kls(obj) + except (TypeError, ValueError) as ex: + raise StructureError( + f"`obj` cannot be parsed as `{kls}`. See the nested exception for details." + ) from ex + + @staticmethod + def _structure_identity( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> object: + kls = cast(type, orig_type) + + if isinstance(obj, kls): + return obj + + raise StructureError( + f"`obj` must be of type `{kls}`, but is of type `{type(obj)}` instead." + ) + + def _structure_dataclass( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> object: + kls = cast(type[DataClass], orig_type) + + if kls is DataClass: + raise StructureError( + f"`type` must be a concrete dataclass type, but is `{DataClass}` instead." + ) + + if isinstance(obj, kls): + values = {f.name: getattr(obj, f.name) for f in fields(kls)} + + return self._make_dataclass(kls, values, set_empty) + + if isinstance(obj, Mapping): + values = self.structure(obj, dict[str, object]) + + return self._make_dataclass(kls, values, set_empty) + + raise StructureError( + f"`obj` must be of type `{kls}` or `{Mapping}`, but is of type `{type(obj)}` instead." + ) + + def _make_dataclass( + self, kls: type[DataClass], values: dict[str, object], set_empty: bool + ) -> object: + type_hints = get_type_hints(kls) + + kwargs = {} + + for field in fields(kls): + value = values.pop(field.name, EMPTY) + + # Fields with `init=False` are initialized in `__post_init__()`. + if not field.init: + continue + + if value is EMPTY: + if not set_empty: + if field.default == MISSING and field.default_factory == MISSING: + raise StructureError( + f"The `{field.name}` field of the dataclass has no default value or factory." + ) + + continue + + if hasattr(kls, "__post_init__"): + raise StructureError( + f"The `{field.name}` field of the dataclass must not be `EMPTY` since `{kls}` has a `__post_init__()` method." + ) + else: + try: + value = self.structure( + value, type_hints[field.name], set_empty=set_empty + ) + except StructureError as ex: + raise StructureError( + f"The `{field.name}` field of the dataclass cannot be structured. See the nested exception for details." + ) from ex + + kwargs[field.name] = value + + if values: + extra_keys = ", ".join(sorted(values.keys())) + + raise StructureError( + f"`obj` must contain only keys corresponding to the fields of `{kls}`, but it contains the following extra keys: {extra_keys}" + ) + + return kls(**kwargs) + + def _structure_dict( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> dict[object, object]: + if isinstance(obj, Mapping): + if len(type_args) != 2: + raise StructureError( + f"`type_` must have a key-value type annotation for `{orig_type}`." + ) + + output = {} + + for k, v in obj.items(): + k = self.structure(k, type_args[0]) + v = self.structure(v, type_args[1]) + + output[k] = v + + return output + + raise StructureError( + f"`obj` must be of type `{Mapping}`, but is of type `{type(obj)}` instead." + ) + + @staticmethod + def _structure_dtype( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> DataType: + if isinstance(obj, DataType): + return obj + + if isinstance(obj, str): + if obj.startswith("torch."): + obj = obj[6:] + + if isinstance(dtype := getattr(torch, obj, None), DataType): + return dtype + + raise StructureError( + f"`obj` must be a `torch.dtype` identifier, but is '{obj}' instead." + ) + + raise StructureError( + f"`obj` must be of type `{DataType}` or `{str}`, but is of type `{type(obj)}` instead." + ) + + @staticmethod + def _structure_device( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> Device: + if isinstance(obj, Device): + return obj + + if isinstance(obj, str): + try: + return Device(obj) + except RuntimeError as ex: + raise StructureError(str(ex)) + + raise StructureError( + f"`obj` must be of type `{Device}` or `{str}`, but is of type `{type(obj)}` instead." + ) + + @staticmethod + def _structure_enum( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> object: + kls = cast(type[Enum], orig_type) + + if isinstance(obj, kls): + return obj + + if isinstance(obj, str): + try: + return kls[obj] + except KeyError: + pass + + values = ", ".join(e.name for e in kls) + + raise StructureError( + f"`obj` must be one of the following enumeration values, but is '{obj}' instead: {values}" + ) from None + + raise StructureError( + f"`obj` must be of type `{kls}` or `{str}`, but is of type `{type(obj)}` instead." + ) + + def _structure_list( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> list[object]: + if isinstance(obj, Sequence): + if len(type_args) != 1: + raise StructureError( + f"`type_` must have an element type annotation for `{orig_type}`." + ) + + output = [] + + for idx, elem in enumerate(obj): + try: + elem = self.structure(elem, type_args[0]) + except StructureError as ex: + raise StructureError( + f"The element at index {idx} in the sequence cannot be structured. See the nested exception for details." + ) from ex + + output.append(elem) + + return output + + raise StructureError( + f"`obj` must be of type `{Sequence}`, but is of type `{type(obj)}` instead." + ) + + @staticmethod + def _structure_literal( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> str: + if isinstance(obj, str): + if obj in type_args: + return obj + + values = ", ".join(str(t) for t in type_args) + + raise StructureError( + f"`obj` must be one of the following values, but is '{obj}' instead: {values}" + ) + + raise StructureError( + f"`obj` must be of type `{str}`, but is of type `{type(obj)}` instead." + ) + + @staticmethod + def _structure_path( + orig_type: object, type_args: tuple[object, ...], obj: object, set_empty: bool + ) -> Path: + if isinstance(obj, Path): + return obj + + if isinstance(obj, str): + return Path(obj) + + raise StructureError( + f"`obj` must be of type `{Path}` or `{str}`, but is of type `{type(obj)}` instead." + ) + + def _structure_set( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> set[object]: + if isinstance(obj, set): + if len(type_args) != 1: + raise StructureError( + f"`type_` must have an element type annotation for `{orig_type}`." + ) + + return {self.structure(e, type_args[0]) for e in obj} + + if isinstance(obj, Sequence): + if len(type_args) != 1: + raise StructureError( + f"`type_` must have an element type annotation for `{orig_type}`." + ) + + tmp = [self.structure(e, type_args[0]) for e in obj] + + output = set(tmp) + + if len(output) != len(tmp): + raise StructureError( + f"All elements of `obj` must be unique to be treated as a `{set}`." + ) + + return output + + raise StructureError( + f"`obj` must be of type `{set}` or `{Sequence}`, but is of type `{type(obj)}` instead." + ) + + def _structure_tuple( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> tuple[object, ...]: + if isinstance(obj, Sequence): + num_args = len(type_args) + + if num_args == 0: + raise StructureError( + f"`type_` must have an element type annotation for `{orig_type}`." + ) + + if num_args == 2 and type_args[1] is Ellipsis: # homogeneous + tmp = self._structure_list(orig_type, type_args[:1], obj, set_empty) + + return tuple(tmp) + + if len(obj) != num_args: # heterogeneous + raise StructureError( + f"`obj` must be have {num_args} element(s), but it has {len(obj)} element(s)." + ) + + output = [] + + for idx, elem in enumerate(obj): + try: + elem = self.structure(elem, type_args[idx]) + except StructureError as ex: + raise StructureError( + f"The element at index {idx} in the sequence cannot be structured. See the nested exception for details." + ) from ex + + output.append(elem) + + return tuple(output) + + raise StructureError( + f"`obj` must be of type `{tuple}` or `{Sequence}`, but is of type `{type(obj)}` instead." + ) + + def _structure_union( + self, + orig_type: object, + type_args: tuple[object, ...], + obj: object, + set_empty: bool, + ) -> object: + is_optional = len(type_args) == 2 and NoneType in type_args + + if is_optional and obj is None: + return obj + + for type_ in type_args: + try: + return self.structure(obj, type_, set_empty=set_empty) + except StructureError: + if is_optional: + raise + + continue + + types = ", ".join(str(t) for t in type_args) + + raise StructureError( + f"`obj` must be parseable as one of the following union elements: {types}" + ) + + def unstructure(self, obj: object) -> object: + kls = type(obj) + + lookup_kls: type + + if is_dataclass(kls): + lookup_kls = DataClass + elif issubclass(kls, Enum): + lookup_kls = Enum + elif issubclass(kls, Mapping): + lookup_kls = Mapping + elif issubclass(kls, Path): + lookup_kls = Path + elif issubclass(kls, Set): + lookup_kls = Set + else: + lookup_kls = kls + + unstructurer = self._unstructurers.get(lookup_kls) + if unstructurer is None: + supported_types = ", ".join(str(t) for t in self._unstructurers.keys()) + + raise StructureError( + f"`obj` must be of one of the following types, but is of type `{type(obj)}` instead: {supported_types}" + ) from None + + try: + return unstructurer(obj) + except StructureError as ex: + raise StructureError( + "`obj` cannot be unstructured. See the nested exception for details." + ) from ex + + @staticmethod + def _unstructure_identity(obj: object) -> object: + return obj + + def _unstructure_dataclass(self, obj: object) -> dict[str, object]: + d = cast(DataClass, obj) + + kls = type(d) + + output: dict[str, object] = {} + + for field in fields(kls): + value = getattr(obj, field.name) + + try: + output[field.name] = self.unstructure(value) + except StructureError as ex: + raise StructureError( + f"The `{field.name}` field of the dataclass cannot be unstructured. See the nested exception for details." + ) from ex + + return output + + @staticmethod + def _unstructure_dtype(obj: object) -> str: + return str(obj)[6:] # strip 'torch.' + + @staticmethod + def _unstructure_device(obj: object) -> str: + return str(obj) + + @staticmethod + def _unstructure_enum(obj: object) -> str: + return cast(Enum, obj).name + + def _unstructure_mapping(self, obj: object) -> dict[object, object]: + output = {} + + m = cast(Mapping[object, object], obj) + + for k, v in m.items(): + k = self.unstructure(k) + v = self.unstructure(v) + + output[k] = self.unstructure(v) + + return output + + @staticmethod + def _unstructure_path(obj: object) -> str: + return str(obj) + + def _unstructure_sequence(self, obj: object) -> list[object]: + output = [] + + s = cast(Sequence[object], obj) + + for idx, elem in enumerate(s): + try: + elem = self.unstructure(elem) + except StructureError as ex: + raise StructureError( + f"The element at index {idx} in the sequence cannot be unstructured. See the nested exception for details." + ) from ex + + output.append(elem) + + return output + + def _unstructure_set(self, obj: object) -> list[object]: + s = cast(set[object], obj) + + return [self.unstructure(e) for e in s] + + +default_value_converter = ValueConverter() + + +T = TypeVar("T") + + +def structure(obj: object, kls: type[T], *, set_empty: bool = False) -> T: + obj = default_value_converter.structure(obj, kls, set_empty=set_empty) + + return cast(T, obj) + + +def unstructure(obj: object) -> object: + return default_value_converter.unstructure(obj) + + +def is_unstructured(obj: object) -> bool: + if isinstance(obj, dict): + for k, v in obj.items(): + if not is_unstructured(k): + return False + + if not is_unstructured(v): + return False + + return True + + if isinstance(obj, list): + for e in obj: + if not is_unstructured(e): + return False + + return True + + return isinstance(obj, NoneType | bool | int | float | str) + + +def merge_unstructured(target: object, source: object) -> object: + def type_error(param_name: str) -> StructureError: + return StructureError( + f"`{param_name}` must be of a composition of types `bool`, `int`, `float`, `str`, `list`, and `dict`." + ) + + if not is_unstructured(target): + raise type_error("target") + + if not is_unstructured(source): + raise type_error("source") + + return _do_merge_unstructured(target, source, "") + + +def _do_merge_unstructured(target: object, source: object, path: str) -> object: + if isinstance(source, dict): + if not isinstance(target, dict): + target = {} + + sep = "." if path else "" + + output = {} + + ignored_keys = set() + + del_keys = source.get("_del_") + if del_keys is not None: + if not isinstance(del_keys, list): + raise StructureError( + f"'{path}{sep}_del_' in `source` must be of type `list`, but is of type `{type(del_keys).__name__}` instead." + ) + + for idx, del_key in enumerate(del_keys): + if not isinstance(del_key, str): + raise StructureError( + f"Each element under '{path}{sep}_del_' in `source` must be of type `str`, but the element at index {idx} is of type `{type(del_key).__name__}` instead." + ) + + ignored_keys.add(del_key) + + for k, v in target.items(): + if k not in ignored_keys: + output[k] = deepcopy(v) + + add_keys = source.get("_add_") + if add_keys is not None: + if not isinstance(add_keys, dict): + raise StructureError( + f"'{path}{sep}_add_' in `source` must be of type `dict`, but is of type `{type(add_keys).__name__}` instead." + ) + + for idx, (add_key, value) in enumerate(add_keys.items()): + if not isinstance(add_key, str): + raise StructureError( + f"Each key under '{path}{sep}_add_' in `source` must be of type `str`, but the key at index {idx} is of type `{type(add_key).__name__}` instead." + ) + + output[add_key] = deepcopy(value) + + set_keys = source.get("_set_") + if set_keys is not None: + if not isinstance(set_keys, dict): + raise StructureError( + f"'{path}{sep}_set_' in `source` must be of type `dict`, but is of type `{type(set_keys).__name__}` instead." + ) + + for idx, (set_key, value) in enumerate(set_keys.items()): + if not isinstance(set_key, str): + raise StructureError( + f"Each key under '{path}{sep}_set_' in `source` must be of type `str`, but the key at index {idx} is of type `{type(set_key).__name__}` instead." + ) + + if set_key not in output: + sub_path = set_key if not path else f"{path}.{set_key}" + + raise StructureError( + f"`target` must contain a '{sub_path}' key since it exists in `source`." + ) from None + + output[set_key] = deepcopy(value) + + for key, source_value in source.items(): + if key == "_del_" or key == "_add_" or key == "_set_": + continue + + # Maintain backwards compatibility with older configuration API. + if key == "_type_": + continue + + sub_path = key if not path else f"{path}.{key}" + + try: + target_value = output[key] + except KeyError: + raise StructureError( + f"`target` must contain a '{sub_path}' key since it exists in `source`." + ) from None + + output[key] = _do_merge_unstructured(target_value, source_value, sub_path) + + return output + + if isinstance(source, list | dict): + return deepcopy(source) + + return source + + +class StructureError(ValueError): + """Raised when a structure or unstructure operation fails.""" diff --git a/src/fairseq2/utils/value_converter.py b/src/fairseq2/utils/value_converter.py deleted file mode 100644 index 93daaf296..000000000 --- a/src/fairseq2/utils/value_converter.py +++ /dev/null @@ -1,359 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import sys -from enum import Enum -from pathlib import Path, PosixPath -from typing import ( - Any, - Callable, - Dict, - Literal, - Sequence, - Set, - Type, - Union, - cast, - get_args, - get_origin, -) - -import torch - -from fairseq2.typing import DataType, Device - - -class ValueConverter: - """Structures and unstructures objects using provided type hints.""" - - _structure_fns: Dict[object, Callable[[Any, Any, Any], Any]] - _unstructure_fns: Dict[object, Callable[[Any, Any], Any]] - - def __init__(self) -> None: - NoneType = type(None) - - self._structure_fns = { - # fmt: off - bool: self._structure_identity, - DataType: self._structure_dtype, - Device: self._structure_device, - dict: self._structure_dict, - float: self._structure_primitive, - Enum: self._structure_enum, - int: self._structure_primitive, - list: self._structure_list, - Literal: self._structure_literal, - NoneType: self._structure_identity, - Path: self._structure_path, - PosixPath: self._structure_path, - set: self._structure_set, - str: self._structure_identity, - tuple: self._structure_tuple, - Union: self._structure_union, - # fmt: on - } - - self._unstructure_fns = { - # fmt: off - bool: self._unstructure_identity, - DataType: self._unstructure_dtype, - Device: self._unstructure_device, - dict: self._unstructure_dict, - float: self._unstructure_identity, - Enum: self._unstructure_enum, - int: self._unstructure_identity, - list: self._unstructure_sequence, - NoneType: self._unstructure_identity, - Path: self._unstructure_path, - PosixPath: self._unstructure_path, - set: self._unstructure_set, - str: self._unstructure_identity, - tuple: self._unstructure_sequence, - # fmt: on - } - - if sys.version_info >= (3, 10): - from types import UnionType - - # Unions types in PEP 604 (i.e. pipe) syntax are represented by - # `types.UnionType`. - self._structure_fns[UnionType] = self._structure_union - - def structure(self, obj: Any, type_hint: Any) -> Any: - """ - :param obj: - The object to structure based on ``type_hint``. - :param type_hint: - The type hint. Typically retrieved via ``typing.get_type_hints()``. - """ - kls, kls_args = get_origin(type_hint), get_args(type_hint) - - if kls is None: - kls = type_hint - - if kls is Any: - return obj - - if isinstance(kls, type): - lookup_kls = Enum if issubclass(kls, Enum) else kls - else: - lookup_kls = kls # typing special form - - try: - fn = self._structure_fns[lookup_kls] - except KeyError: - supported = ", ".join(str(t) for t in self._structure_fns.keys()) - - raise ValueError( - f"`type_hint` of `obj` must be of one of the following, but is `{type_hint}` instead: {supported}" - ) from None - - try: - return fn(kls, kls_args, obj) - except (TypeError, ValueError) as ex: - raise TypeError( - f"`obj` cannot be structured to type `{type_hint}`. See nested exception for details." - ) from ex - - @staticmethod - def _structure_identity(kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, kls): - return obj - - raise TypeError( - f"`obj` must be of type `{kls}`, but is of type `{type(obj)}` instead." - ) - - @staticmethod - def _structure_primitive(kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, kls): - return obj - - return kls(obj) - - @staticmethod - def _structure_dtype(kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, DataType): - return obj - - if isinstance(obj, str): - if obj.startswith("torch."): - obj = obj[6:] - - if isinstance(dtype := getattr(torch, obj, None), DataType): - return dtype - - raise ValueError( - f"`obj` must be a `torch.dtype` identifier, but is '{obj}' instead." - ) - - raise TypeError( - f"`obj` must be of type `{DataType}` or `{str}`, but is of type `{type(obj)}` instead." - ) - - @staticmethod - def _structure_device(kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, Device): - return obj - - if isinstance(obj, str): - try: - return Device(obj) - except RuntimeError as ex: - raise ValueError(str(ex)) - - raise TypeError( - f"`obj` must be of type `{Device}` or `{str}`, but is of type `{type(obj)}` instead." - ) - - def _structure_dict(self, kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, dict): - if len(kls_args) != 2: - raise TypeError("`type_hint` has no type annotation for `dict`.") - - output = {} - - for k, v in obj.items(): - k = self.structure(k, kls_args[0]) - v = self.structure(v, kls_args[1]) - - output[k] = v - - return output - - raise TypeError( - f"`obj` must be of type `{dict}`, but is of type `{type(obj)}` instead." - ) - - @staticmethod - def _structure_enum(kls: Any, kls_args: Any, obj: Any) -> Any: - enum_kls = cast(Type[Enum], kls) - - if isinstance(obj, enum_kls): - return obj - - if isinstance(obj, str): - try: - return enum_kls[obj] # type: ignore[index] - except KeyError: - raise ValueError( - f"`obj` must be one of the following enumeration values, but is '{obj}' instead: {', '.join(e.name for e in enum_kls)}." - ) from None - - raise TypeError( - f"`obj` must be of type `{kls}` or `{str}`, but is of type `{type(obj)}` instead." - ) - - def _structure_list(self, kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, Sequence): - if len(kls_args) != 1: - raise TypeError("`type_hint` has no type annotation for `list`.") - - return [self.structure(e, kls_args[0]) for e in obj] - - raise TypeError( - f"`obj` must be of type `{Sequence}`, but is of type `{type(obj)}` instead." - ) - - @staticmethod - def _structure_literal(kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, str): - if obj in kls_args: - return obj - - raise ValueError( - f"`obj` must be one of the following values, but is '{obj}' instead: {', '.join(kls_args)}." - ) - - raise TypeError( - f"`obj` must be of type `{str}`, but is of type `{type(obj)}` instead." - ) - - @staticmethod - def _structure_path(kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, Path): - return obj - - if isinstance(obj, str): - return Path(obj) - - raise TypeError( - f"`obj` must be of type `{Path}` or `{str}`, but is of type `{type(obj)}` instead." - ) - - def _structure_set(self, kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, Set): - if len(kls_args) != 1: - raise TypeError("`type_hint` has no type annotation for `set`.") - - return {self.structure(e, kls_args[0]) for e in obj} - - if isinstance(obj, Sequence): - if len(kls_args) != 1: - raise TypeError("`type_hint` has no type annotation for `set`.") - - tmp = [self.structure(e, kls_args[0]) for e in obj] - - output = set(tmp) - - if len(output) != len(tmp): - raise ValueError( - f"All elements of `obj` must be unique to be treated as a `{set}`." - ) - - return output - - raise TypeError( - f"`obj` must be of type `{set}` or `{Sequence}`, but is of type `{type(obj)}` instead." - ) - - def _structure_tuple(self, kls: Any, kls_args: Any, obj: Any) -> Any: - if isinstance(obj, Sequence): - num_args = len(kls_args) - - if num_args == 0: - raise TypeError("`type_hint` has no type annotation for `tuple`.") - - if num_args == 2 and kls_args[1] is Ellipsis: # homogeneous - tmp = [self.structure(e, kls_args[0]) for e in obj] - - return tuple(tmp) - - if len(obj) != num_args: # heterogeneous - raise TypeError( - f"`obj` must be have {num_args} elements, but it has {len(obj)} elements." - ) - - output = [] - - for i, e in enumerate(obj): - output.append(self.structure(e, kls_args[i])) - - return tuple(output) - - raise TypeError( - f"`obj` must be of type `{tuple}` or `{Sequence}`, but is of type `{type(obj)}` instead." - ) - - def _structure_union(self, kls: Any, kls_args: Any, obj: Any) -> Any: - for kls_ in kls_args: - try: - return self.structure(obj, kls_) - except (TypeError, ValueError): - continue - - raise TypeError( - f"`obj` must be parseable as one of the following union types: {', '.join(str(t) for t in kls_args)}" - ) - - def unstructure(self, obj: Any) -> Any: - kls = type(obj) - - lookup_kls = Enum if issubclass(kls, Enum) else kls - - try: - fn = self._unstructure_fns[lookup_kls] - except KeyError: - supported_types = ", ".join(str(t) for t in self._unstructure_fns.keys()) - - raise TypeError( - f"`obj` must be of one of the following types, but is of type `{type(obj)}` instead: {supported_types}" - ) from None - - return fn(kls, obj) - - @staticmethod - def _unstructure_identity(kls: Any, obj: Any) -> Any: - return obj - - @staticmethod - def _unstructure_dtype(kls: Any, obj: Any) -> Any: - return str(obj)[6:] # strip 'torch.' - - @staticmethod - def _unstructure_device(kls: Any, obj: Any) -> Any: - return str(obj) - - def _unstructure_dict(self, kls: Any, obj: Any) -> Any: - return {self.unstructure(k): self.unstructure(v) for k, v in obj.items()} - - def _unstructure_enum(self, kls: Any, obj: Any) -> Any: - return obj.name - - def _unstructure_set(self, kls: Any, obj: Any) -> Any: - return [self.unstructure(e) for e in obj] - - def _unstructure_sequence(self, kls: Any, obj: Any) -> Any: - return [self.unstructure(e) for e in obj] - - @staticmethod - def _unstructure_path(kls: Any, obj: Any) -> Any: - return str(obj) - - -default_value_converter = ValueConverter() diff --git a/src/fairseq2/utils/version.py b/src/fairseq2/utils/version.py index 48c98b830..ba3d787e6 100644 --- a/src/fairseq2/utils/version.py +++ b/src/fairseq2/utils/version.py @@ -24,8 +24,6 @@ def _get_torch_version() -> Version: def torch_greater_or_equal(major: int, minor: int) -> bool: - """Return ``True`` if the installed version of PyTorch is greater than or - equal to the specified major-minor version.""" if TORCH_VERSION.major <= major - 1: return False diff --git a/src/fairseq2/utils/yaml.py b/src/fairseq2/utils/yaml.py new file mode 100644 index 000000000..341cf6839 --- /dev/null +++ b/src/fairseq2/utils/yaml.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from pathlib import Path +from typing import IO, Protocol, TypeAlias + +import yaml +from yaml import YAMLError + + +class YamlLoader(Protocol): + def __call__(self, input_: Path | IO[str]) -> list[object]: + ... + + +class YamlDumper(Protocol): + def __call__(self, obj: object, output: Path | IO[str]) -> None: + ... + + +YamlError: TypeAlias = YAMLError + + +def load_yaml(input_: Path | IO[str]) -> list[object]: + if isinstance(input_, Path): + with input_.open() as fp: + return load_yaml(fp) + + itr = yaml.safe_load_all(input_) + + return list(itr) + + +def dump_yaml(obj: object, output: Path | IO[str]) -> None: + if isinstance(output, Path): + with output.open("w") as fp: + dump_yaml(obj, fp) + else: + yaml.safe_dump(obj, output, sort_keys=False) + + +def read_yaml(s: str) -> object: + return yaml.safe_load(s) diff --git a/tests/common.py b/tests/common.py index 66dd299c7..5fd95f679 100644 --- a/tests/common.py +++ b/tests/common.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, List, Union +from typing import Any import torch from torch import Tensor @@ -18,7 +18,7 @@ device = Device("cpu") -def assert_close(a: Tensor, b: Union[Tensor, List[Any]]) -> None: +def assert_close(a: Tensor, b: Tensor | list[Any]) -> None: """Assert that ``a`` and ``b`` are element-wise equal within a tolerance.""" if not isinstance(b, Tensor): b = torch.tensor(b, device=device, dtype=a.dtype) @@ -26,7 +26,7 @@ def assert_close(a: Tensor, b: Union[Tensor, List[Any]]) -> None: torch.testing.assert_close(a, b) # type: ignore[attr-defined] -def assert_equal(a: Tensor, b: Union[Tensor, List[Any]]) -> None: +def assert_equal(a: Tensor, b: Tensor | list[Any]) -> None: """Assert that ``a`` and ``b`` are element-wise equal.""" if not isinstance(b, Tensor): b = torch.tensor(b, device=device, dtype=a.dtype) diff --git a/tests/integration/generation/test_incremental_decode.py b/tests/integration/generation/test_incremental_decode.py index 398235d40..620921abe 100644 --- a/tests/integration/generation/test_incremental_decode.py +++ b/tests/integration/generation/test_incremental_decode.py @@ -8,7 +8,7 @@ import torch -from fairseq2.models.nllb import load_nllb_tokenizer +from fairseq2.data.text import load_text_tokenizer from fairseq2.models.transformer import load_transformer_model from fairseq2.nn import IncrementalStateBag from fairseq2.nn.padding import pad_seqs @@ -29,7 +29,7 @@ def test_incremental_decoding_works() -> None: model.eval() - tokenizer = load_nllb_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) pad_idx = tokenizer.vocab_info.pad_idx diff --git a/tests/integration/generation/test_sampling.py b/tests/integration/generation/test_sampling.py index 8bfc3e331..ffa3bdbba 100644 --- a/tests/integration/generation/test_sampling.py +++ b/tests/integration/generation/test_sampling.py @@ -8,8 +8,8 @@ import torch +from fairseq2.data.text import load_text_tokenizer from fairseq2.generation import SamplingSeq2SeqGenerator, TextTranslator, TopKSampler -from fairseq2.models.nllb import load_nllb_tokenizer from fairseq2.models.transformer import load_transformer_model from tests.common import device @@ -26,7 +26,7 @@ def test_greedy_sampling() -> None: model_name, device=device, dtype=torch.float32, progress=False ) - tokenizer = load_nllb_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) sampler = TopKSampler(k=1) diff --git a/tests/integration/generation/test_step_processor.py b/tests/integration/generation/test_step_processor.py index f3ce4da80..66801784b 100644 --- a/tests/integration/generation/test_step_processor.py +++ b/tests/integration/generation/test_step_processor.py @@ -4,19 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import ClassVar, Final, Sequence +from collections.abc import Sequence +from typing import ClassVar, Final import pytest import torch -from fairseq2.data.text import TextTokenizer +from fairseq2.data.text import TextTokenizer, load_text_tokenizer from fairseq2.generation import ( BannedSequenceProcessor, BeamSearchSeq2SeqGenerator, TextTranslator, ) from fairseq2.models.encoder_decoder import EncoderDecoderModel -from fairseq2.models.nllb import load_nllb_tokenizer from fairseq2.models.transformer import load_transformer_model from tests.common import device @@ -69,7 +69,7 @@ def setup_class(cls) -> None: model_name, device=device, dtype=torch.float32, progress=False ) - cls.tokenizer = load_nllb_tokenizer(model_name, progress=False) + cls.tokenizer = load_text_tokenizer(model_name) @classmethod def teardown_class(cls) -> None: diff --git a/tests/integration/models/test_llama.py b/tests/integration/models/test_llama.py index cf3286896..633d44589 100644 --- a/tests/integration/models/test_llama.py +++ b/tests/integration/models/test_llama.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. import os +from typing import cast import pytest -from fairseq2.assets import default_asset_store, default_download_manager +from fairseq2.context import get_runtime_context from fairseq2.models.llama import create_llama_model, llama_archs from fairseq2.models.llama.integ import convert_to_reference_checkpoint from fairseq2.models.llama.loader import convert_llama_checkpoint @@ -21,11 +22,13 @@ "FAIR_ENV_CLUSTER" not in os.environ, reason="checkpoints only on faircluster" ) def test_convert_to_reference_checkpoint() -> None: + context = get_runtime_context() + model_config = llama_archs.get("llama2_7b") - card = default_asset_store.retrieve_card("llama2_7b") + card = context.asset_store.retrieve_card("llama2_7b") - path = default_download_manager.download_checkpoint( + path = context.asset_download_manager.download_checkpoint( card.field("checkpoint").as_uri(), model_name="llama2_7b", progress=False ) @@ -44,4 +47,4 @@ def test_convert_to_reference_checkpoint() -> None: model = create_llama_model(model_config, device=device) # This should work. - model.load_state_dict(checkpoint["model"]) + model.load_state_dict(cast(dict[str, object], checkpoint["model"])) diff --git a/tests/integration/models/test_nllb.py b/tests/integration/models/test_nllb.py index e028c4e8c..7ebccc4c6 100644 --- a/tests/integration/models/test_nllb.py +++ b/tests/integration/models/test_nllb.py @@ -9,8 +9,8 @@ import pytest import torch +from fairseq2.data.text import load_text_tokenizer from fairseq2.generation import BeamSearchSeq2SeqGenerator, TextTranslator -from fairseq2.models.nllb import load_nllb_tokenizer from fairseq2.models.transformer import load_transformer_model from tests.common import device @@ -25,7 +25,7 @@ def test_load_dense_distill_600m() -> None: model_name, device=device, dtype=torch.float32, progress=False ) - tokenizer = load_nllb_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) generator = BeamSearchSeq2SeqGenerator(model, echo_prompt=True, max_seq_len=128) @@ -56,7 +56,7 @@ def test_load_dense_distill_600m() -> None: def test_tokenizer_special_tokens() -> None: model_name = "nllb-200_dense_distill_600m" - tokenizer = load_nllb_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) text = "Hello world!" diff --git a/tests/integration/models/test_s2t_transformer.py b/tests/integration/models/test_s2t_transformer.py index f07258541..1fdfbedac 100644 --- a/tests/integration/models/test_s2t_transformer.py +++ b/tests/integration/models/test_s2t_transformer.py @@ -9,12 +9,9 @@ import torch +from fairseq2.data.text import TextTokenizer, load_text_tokenizer from fairseq2.generation import BeamSearchSeq2SeqGenerator, SequenceToTextConverter -from fairseq2.models.s2t_transformer import ( - S2TTransformerTokenizer, - load_s2t_transformer_model, - load_s2t_transformer_tokenizer, -) +from fairseq2.models.s2t_transformer import load_s2t_transformer_model from fairseq2.models.transformer import TransformerModel from tests.common import device @@ -34,7 +31,7 @@ def test_load_s2t_transformer_mustc_st_jt_m() -> None: model_name, device=device, dtype=torch.float32, progress=False ) - tokenizer = load_s2t_transformer_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) assert_translation(model, tokenizer, expected=TRANSFORMER_DE) @@ -46,7 +43,7 @@ def test_load_s2t_conformer_covost_st_en_de() -> None: model_name, device=device, dtype=torch.float32, progress=False ) - tokenizer = load_s2t_transformer_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) assert_translation(model, tokenizer, expected=CONFORMER_DE) @@ -58,15 +55,15 @@ def test_load_s2t_conformer_rel_pos_covost_st_en_de() -> None: model_name, device=device, dtype=torch.float32, progress=False ) - tokenizer = load_s2t_transformer_tokenizer(model_name, progress=False) + tokenizer = load_text_tokenizer(model_name) assert_translation(model, tokenizer, expected=CONFORMER_DE_REL_POS) def assert_translation( - model: TransformerModel, tokenizer: S2TTransformerTokenizer, expected: str + model: TransformerModel, tokenizer: TextTokenizer, expected: str ) -> None: - fbank = torch.load(TEST_FBANK_PATH).to(device) + fbank = torch.load(TEST_FBANK_PATH, weights_only=True).to(device) generator = BeamSearchSeq2SeqGenerator(model) diff --git a/tests/integration/parquet/test_parquet_dataloader.py b/tests/integration/parquet/test_parquet_dataloader.py index 846796f2a..8cff04c0d 100644 --- a/tests/integration/parquet/test_parquet_dataloader.py +++ b/tests/integration/parquet/test_parquet_dataloader.py @@ -10,7 +10,8 @@ import string import tempfile from collections import Counter -from typing import Any, Dict, Generator, List, Union +from collections.abc import Generator +from typing import Any import pytest @@ -38,7 +39,7 @@ def gen_random_string(length: int) -> str: def generate_random_pandas_df(size: int, seed: int = 123) -> pd.DataFrame: np_rs = np.random.RandomState(seed) - df: Dict[str, Union[NDArray[Any], List[Any]]] = {} + df: dict[str, NDArray[Any] | list[Any]] = {} df["int_col"] = np_rs.randint(0, 200, size) df["float_col"] = np_rs.randn(size) @@ -104,7 +105,7 @@ def test_simple_dataload(self, multi_partition_file: str) -> None: nb_parallel_fragments=2, seed=333, ) - res: List[pd.DataFrame] = list(parquet_iterator(config)) + res: list[pd.DataFrame] = list(parquet_iterator(config)) assert all(isinstance(x, pa.Table) for x in res) @@ -162,7 +163,7 @@ def test_filtered_with_columns_dataload(self, multi_partition_file: str) -> None output_format=ParquetBatchFormat.pandas, ) - res: List[pd.DataFrame] = list(parquet_iterator(config)) + res: list[pd.DataFrame] = list(parquet_iterator(config)) assert list(res[0].columns) == ["string_col2", "list_int_col", "float_col"] @@ -193,7 +194,7 @@ def test_ordered_dataload(self, multi_partition_file: str) -> None: seed=123, output_format=ParquetBatchFormat.pandas, ) - res: List[pd.DataFrame] = list(parquet_iterator(config)) + res: list[pd.DataFrame] = list(parquet_iterator(config)) length_by_batches = [tt["list_int_col"].apply(len) for tt in res] length_by_batches_diff = max(tt.max() - tt.min() for tt in length_by_batches) total_length = sum(map(len, length_by_batches)) @@ -211,7 +212,7 @@ def test_ordered_max_token_dataload(self, multi_partition_file: str) -> None: seed=123, output_format=ParquetBatchFormat.pandas, ) - res: List[pd.DataFrame] = list(parquet_iterator(config)) + res: list[pd.DataFrame] = list(parquet_iterator(config)) length_by_batches = [tt["list_int_col"].apply(len) for tt in res] length_by_batches_diff = max(tt.max() - tt.min() for tt in length_by_batches) max_padded_total_length = max(tt.max() * len(tt) for tt in length_by_batches) @@ -232,7 +233,7 @@ def test_ordered_max_token_single_file_dataload(self, single_file: str) -> None: batch_size=10, seed=333, ) - res: List[pa.Table] = list(parquet_iterator(config)) + res: list[pa.Table] = list(parquet_iterator(config)) assert Counter(map(len, res)) == Counter({10: 100}) diff --git a/tests/unit/assets/test_card.py b/tests/unit/assets/test_card.py index 862bf1d75..827c09201 100644 --- a/tests/unit/assets/test_card.py +++ b/tests/unit/assets/test_card.py @@ -7,7 +7,6 @@ from __future__ import annotations from pathlib import Path -from typing import Dict, List, Optional, Set import pytest @@ -21,18 +20,18 @@ def setup_method(self) -> None: "field3": 3, } - root_card = AssetCard(root_metadata) + root_card = AssetCard("root-card", root_metadata) - base_metadata = { + base_metadata: dict[str, object] = { "name": "base-card", "field1": "base-foo1", "field2": {"sub-field1": "sub-foo1"}, "field8": [1, "b", 3], } - base_card = AssetCard(base_metadata, root_card) + base_card = AssetCard("base-card", base_metadata, root_card) - metadata = { + metadata: dict[str, object] = { "name": "test-card", "field1": "foo1", "field2": { @@ -46,7 +45,7 @@ def setup_method(self) -> None: "field10": None, } - self.card = AssetCard(metadata, base_card) + self.card = AssetCard("test-card", metadata, base_card) def test_field_works(self) -> None: value = self.card.field("field1").as_(str) @@ -65,93 +64,85 @@ def test_field_works(self) -> None: assert int_value == 3 - none_value = self.card.field("field10").as_(Optional[str]) + none_value = self.card.field("field10").as_(str | None) assert none_value is None def test_as_raises_error_when_field_type_is_incorrect(self) -> None: with pytest.raises( - AssetCardError, - match=rf"^The value of the field 'field1' of the asset card 'test-card' cannot be retrieved as `{int}`. See nested exception for details\.$", + AssetCardError, match=rf"^The value of the 'field1' field of the 'test-card' asset card cannot be parsed as `{int}`. See the nested exception for details\.$", # fmt: skip ): self.card.field("field1").as_(int) with pytest.raises( - AssetCardError, - match=rf"^The value of the field 'field2\.sub-field1' of the asset card 'test-card' cannot be retrieved as `{int}`. See nested exception for details\.$", + AssetCardError, match=rf"^The value of the 'field2\.sub-field1' field of the 'test-card' asset card cannot be parsed as `{int}`. See the nested exception for details\.$", # fmt: skip ): self.card.field("field2").field("sub-field1").as_(int) def test_as_raises_error_when_field_does_not_exist(self) -> None: with pytest.raises( - AssetCardFieldNotFoundError, - match=r"^The asset card 'test-card' must have a field named 'field11'\.$", + AssetCardFieldNotFoundError, match=r"^The 'test-card' asset card does not have a field named 'field11'\.$", # fmt: skip ): self.card.field("field11").as_(str) with pytest.raises( - AssetCardError, - match=r"^The asset card 'test-card' must have a field named 'field10\.sub-field'\.$", + AssetCardFieldNotFoundError, match=r"^The 'test-card' asset card does not have a field named 'field10\.sub-field'\.$", # fmt: skip ): self.card.field("field10").field("sub-field").as_(str) def test_as_raises_error_when_field_is_empty(self) -> None: with pytest.raises( AssetCardError, - match=r"^The value of the field 'field4' of the asset card 'test-card' must not be empty\.$", + match=r"^The value of the 'field4' field of the 'test-card' asset card is empty\.$", ): self.card.field("field4").as_(str) with pytest.raises( - AssetCardError, - match=r"^The value of the field 'field5' of the asset card 'test-card' must not be empty\.$", + AssetCardError, match=r"^The value of the 'field5' field of the 'test-card' asset card is empty\.$", # fmt: skip ): - self.card.field("field5").as_(List[str]) + self.card.field("field5").as_(list[str]) def test_as_works_when_allow_empty_is_true(self) -> None: value1 = self.card.field("field4").as_(str, allow_empty=True) assert value1 == "" - value2 = self.card.field("field5").as_(List[str], allow_empty=True) + value2 = self.card.field("field5").as_(list[str], allow_empty=True) assert value2 == [] def test_as_list_works(self) -> None: - value = self.card.field("field7").as_(List[int]) + value = self.card.field("field7").as_(list[int]) assert value == [1, 3, 2] def test_as_list_raises_error_when_field_is_not_a_valid_list(self) -> None: with pytest.raises( - AssetCardError, - match=r"The value of the field 'field7' of the asset card 'test-card' cannot be retrieved as `typing\.List\[str\]`. See nested exception for details\.$", + AssetCardError, match=r"The value of the 'field7' field of the 'test-card' asset card cannot be parsed as `list\[str\]`. See the nested exception for details\.$", # fmt: skip ): - self.card.field("field7").as_(List[str]) + self.card.field("field7").as_(list[str]) def test_as_dict_works(self) -> None: - value = self.card.field("field2").as_(Dict[str, str]) + value = self.card.field("field2").as_(dict[str, str]) assert value == {"sub-field2": "sub-foo2"} def test_as_dict_raises_error_when_field_is_not_a_valid_dict(self) -> None: with pytest.raises( - AssetCardError, - match=r"The value of the field 'field2' of the asset card 'test-card' cannot be retrieved as `typing\.Dict\[str, int\]`. See nested exception for details\.$", + AssetCardError, match=r"The value of the 'field2' field of the 'test-card' asset card cannot be parsed as `dict\[str, int\]`. See the nested exception for details\.$", # fmt: skip ): - self.card.field("field2").as_(Dict[str, int]) + self.card.field("field2").as_(dict[str, int]) def test_as_set_works(self) -> None: - value = self.card.field("field7").as_(Set[int]) + value = self.card.field("field7").as_(set[int]) assert value == {1, 2, 3} def test_as_set_raises_error_when_field_is_not_a_valid_set(self) -> None: with pytest.raises( - AssetCardError, - match=r"The value of the field 'field7' of the asset card 'test-card' cannot be retrieved as `typing\.Set\[str\]`. See nested exception for details\.$", + AssetCardError, match=r"The value of the 'field7' field of the 'test-card' asset card cannot be parsed as `set\[str\]`. See the nested exception for details\.$", # fmt: skip ): - self.card.field("field7").as_(Set[str]) + self.card.field("field7").as_(set[str]) def test_as_one_of_works(self) -> None: value = self.card.field("field1").as_one_of({"foo2", "foo1"}) @@ -160,8 +151,7 @@ def test_as_one_of_works(self) -> None: def test_as_one_of_raises_error_when_field_is_not_one_of_valid_values(self) -> None: with pytest.raises( - AssetCardError, - match=rf"The value of the field 'field1' of the asset card 'test-card' must be one of \{['foo2', 'foo3']}, but is 'foo1' instead\.$", + AssetCardError, match=r"The value of the 'field1' field of the 'test-card' asset card is expected to be one of the following values, but is 'foo1' instead: foo2, foo3$", # fmt: skip ): self.card.field("field1").as_one_of({"foo3", "foo2"}) @@ -184,15 +174,13 @@ def test_as_uri_works(self) -> None: def test_as_uri_raises_error_when_field_type_is_incorrect(self) -> None: with pytest.raises( - AssetCardError, - match=rf"The value of the field 'field3' of the asset card 'test-card' cannot be retrieved as `{str}`. See nested exception for details\.$", + AssetCardError, match=rf"The value of the 'field3' field of the 'test-card' asset card cannot be parsed as `{str}`. See the nested exception for details\.$", # fmt: skip ): self.card.field("field3").as_uri() def test_as_uri_raises_error_when_field_is_not_uri_or_absolute_path(self) -> None: with pytest.raises( - AssetCardError, - match=r"The value of the field 'field1' of the asset card 'test-card' must be a URI or an absolute pathname, but is 'foo1' instead\.$", + AssetCardError, match=r"The value of the 'field1' field of the 'test-card' asset card is expected to be a URI or an absolute pathname, but is 'foo1' instead\.$", # fmt: skip ): self.card.field("field1").as_uri() @@ -203,8 +191,7 @@ def test_as_filename_works(self) -> None: def test_as_filename_raises_error_when_field_is_not_filename(self) -> None: with pytest.raises( - AssetCardError, - match=r"^The value of the field 'field6' of the asset card 'test-card' must be a filename, but is 'invalid/filename' instead\.$", + AssetCardError, match=r"^The value of the 'field6' field of the 'test-card' asset card is expected to be a filename, but is 'invalid/filename' instead\.$", # fmt: skip ): self.card.field("field6").as_filename() @@ -223,7 +210,27 @@ def test_set_works(self) -> None: def test_set_raises_error_when_path_conflicts(self) -> None: with pytest.raises( - AssetCardError, - match=r"^The asset card 'test-card' cannot have a field named 'field1.field2' due to path conflict at 'field1'\.$", + AssetCardError, match=r"^The 'test-card' asset card cannot have a field named 'field1.field2' due to path conflict at 'field1'\.$", # fmt: skip ): self.card.field("field1").field("field2").set("foo") + + def test_flatten_works(self) -> None: + card = self.card.flatten() + + expected_metadata = { + "name": "test-card", + "field1": "foo1", + "field2": { + "sub-field2": "sub-foo2", + }, + "field3": 3, + "field4": "", + "field5": [], + "field6": "invalid/filename", + "field7": [1, 3, 2], + "field8": [1, "b", 3], + "field9": "http://foo.com", + "field10": None, + } + + assert card.metadata == expected_metadata diff --git a/tests/unit/data/audio/test_waveform_to_fbank_converter.py b/tests/unit/data/audio/test_waveform_to_fbank_converter.py index a5953f8f7..b3bd1b88f 100644 --- a/tests/unit/data/audio/test_waveform_to_fbank_converter.py +++ b/tests/unit/data/audio/test_waveform_to_fbank_converter.py @@ -6,8 +6,9 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path -from typing import Final, Sequence +from typing import Final import pytest import torch diff --git a/tests/unit/data/data_pipeline/test_dynamic_bucket.py b/tests/unit/data/data_pipeline/test_dynamic_bucket.py index 096749ae7..69ffd156b 100644 --- a/tests/unit/data/data_pipeline/test_dynamic_bucket.py +++ b/tests/unit/data/data_pipeline/test_dynamic_bucket.py @@ -220,6 +220,36 @@ def test_op_works_with_min_and_drop_set(self, drop: bool) -> None: pipeline.reset() + def test_op_works_with_bucket_creation_fn_set(self) -> None: + seq = list(range(1, 7)) + + threshold = 6 + cost_fn = lambda x: x + bucket_creation_fn = lambda l: ([l[:-1]], [l[-1]]) + + pipeline = ( + read_sequence(seq) + .dynamic_bucket( + threshold, + cost_fn, + bucket_creation_fn=bucket_creation_fn, + ) + .and_return() + ) + + for _ in range(2): + it = iter(pipeline) + + assert next(it) == [1, 2] + assert next(it) == [3, 4] + assert next(it) == [5] + assert next(it) == [6] + + with pytest.raises(StopIteration): + next(it) + + pipeline.reset() + def test_op_raises_error_when_threshold_is_nonpositive(self) -> None: with pytest.raises( ValueError, match=r"^`threshold` must be greater than zero\.$" diff --git a/tests/unit/data/data_pipeline/test_read_iterator.py b/tests/unit/data/data_pipeline/test_read_iterator.py index 853ca9a6e..4a4f6c42b 100644 --- a/tests/unit/data/data_pipeline/test_read_iterator.py +++ b/tests/unit/data/data_pipeline/test_read_iterator.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterator, TypeVar +from collections.abc import Iterator +from typing import TypeVar import pytest from typing_extensions import Self diff --git a/tests/unit/data/data_pipeline/test_yield_from.py b/tests/unit/data/data_pipeline/test_yield_from.py index fd7bd02b6..e9b723722 100644 --- a/tests/unit/data/data_pipeline/test_yield_from.py +++ b/tests/unit/data/data_pipeline/test_yield_from.py @@ -6,8 +6,6 @@ from __future__ import annotations -from typing import Tuple - import pytest from fairseq2.data import DataPipeline, DataPipelineError, read_sequence @@ -15,7 +13,7 @@ class TestYieldFromOp: def test_op_works(self) -> None: - def fn(d: Tuple[int, int]) -> DataPipeline: + def fn(d: tuple[int, int]) -> DataPipeline: a, b = d seq = list(range(a, b)) @@ -42,7 +40,7 @@ def fn(d: int) -> DataPipeline: next(iter(pipeline)) def test_op_saves_and_restores_its_state(self) -> None: - def fn(d: Tuple[int, int]) -> DataPipeline: + def fn(d: tuple[int, int]) -> DataPipeline: a, b = d seq = list(range(a, b)) diff --git a/tests/unit/data/test_file_mapper.py b/tests/unit/data/test_file_mapper.py index e31db7279..2d73a21db 100644 --- a/tests/unit/data/test_file_mapper.py +++ b/tests/unit/data/test_file_mapper.py @@ -7,7 +7,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Final, Optional +from typing import Any, Final import pytest @@ -71,7 +71,7 @@ def assert_file( output: FileMapperOutput, pathname: Path, offset: int = 0, - size: Optional[int] = None, + size: int | None = None, ) -> None: data = output["data"] diff --git a/tests/unit/data/test_memory.py b/tests/unit/data/test_memory.py index 09712ee05..ad1c3b503 100644 --- a/tests/unit/data/test_memory.py +++ b/tests/unit/data/test_memory.py @@ -47,9 +47,9 @@ def test_init_works_when_input_buffer_is_shared_and_is_of_type_float(self) -> No assert view.shape == (12,) assert view.strides == (1,) - view = view.cast("f") + float_view = view.cast("f") - assert view.tolist() == pytest.approx([0.2, 0.4, 0.6]) + assert float_view.tolist() == pytest.approx([0.2, 0.4, 0.6]) def test_init_works_when_copy_is_true(self) -> None: arr = array("B", [0, 1, 2, 3]) diff --git a/tests/unit/data/test_read_pickle_wrapped_iterator.py b/tests/unit/data/test_read_pickle_wrapped_iterator.py new file mode 100644 index 000000000..a327e9e31 --- /dev/null +++ b/tests/unit/data/test_read_pickle_wrapped_iterator.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterator + +import pytest + +from fairseq2.data import read_iterator +from fairseq2.data.utils import read_pickle_wrapped_iterator + + +def example_generator() -> Iterator[int]: + for i in range(10): + yield i + + +class TestReadAndPickleWrapIterator: + def test_read_and_pickle_wrap_iterator_works(self) -> None: + with pytest.raises(TypeError): + read_iterator( + example_generator(), + reset_fn=lambda x: example_generator(), + infinite=False, + ).and_return() + + pipeline = read_pickle_wrapped_iterator(example_generator).and_return() + + it = iter(pipeline) + + assert next(it) == 0 + assert next(it) == 1 + + state = pipeline.state_dict() + + assert next(it) == 2 + assert next(it) == 3 + assert next(it) == 4 + + pipeline.load_state_dict(state) + + assert next(it) == 2 + assert next(it) == 3 + assert next(it) == 4 + + pipeline.reset() + + for _ in range(2): + assert list(pipeline) == [*range(10)] + pipeline.reset() diff --git a/tests/unit/data/text/test_sentencepiece.py b/tests/unit/data/text/test_sentencepiece.py index 4354e476a..ac4faafed 100644 --- a/tests/unit/data/text/test_sentencepiece.py +++ b/tests/unit/data/text/test_sentencepiece.py @@ -7,13 +7,14 @@ from __future__ import annotations import pickle +from collections.abc import Sequence from pathlib import Path -from typing import ClassVar, Final, List, Optional, Sequence +from typing import ClassVar, Final import pytest import torch -from fairseq2.data.text import ( +from fairseq2.data.text.tokenizers.sentencepiece import ( SentencePieceDecoder, SentencePieceEncoder, SentencePieceModel, @@ -26,7 +27,7 @@ class TestSentencePieceModel: text: ClassVar[str] - token_indices: ClassVar[List[int]] + token_indices: ClassVar[list[int]] @classmethod def setup_class(cls) -> None: @@ -260,7 +261,7 @@ def test_pickle_works(self) -> None: @staticmethod def build_model( - control_symbols: Optional[Sequence[str]] = None, + control_symbols: Sequence[str] | None = None, ) -> SentencePieceModel: symbols = ["@0"] diff --git a/tests/unit/data/text/test_str_splitter.py b/tests/unit/data/text/test_str_splitter.py index c88c9a500..f683bf56d 100644 --- a/tests/unit/data/text/test_str_splitter.py +++ b/tests/unit/data/text/test_str_splitter.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence import pytest @@ -58,8 +58,16 @@ def test_call_works_when_names_are_specified(self) -> None: assert splitter(s) == {"a": "1", "b": "2", "c": "3"} + @pytest.mark.parametrize("indices", [0, 1, 4]) + def test_call_works_when_single_index_is_specified(self, indices: int) -> None: + s = "0,1,2,3,4" + + splitter = StrSplitter(sep=",", indices=indices) + + assert splitter(s) == str(indices) + @pytest.mark.parametrize("indices", [[0], [1], [4], [2, 3], [1, 2, 4]]) - def test_call_works_when_indices_are_specified( + def test_call_works_when_multiple_indices_are_specified( self, indices: Sequence[int] ) -> None: s = "0,1,2,3,4" @@ -72,9 +80,12 @@ def test_call_works_when_indices_are_specified( "indices,expected", [ ([0], [1, 2, 3, 4]), + (0, [1, 2, 3, 4]), ([4], [0, 1, 2, 3]), + (4, [0, 1, 2, 3]), ([2, 3], [0, 1, 4]), ([1, 2, 4], [0, 3]), + ([1, 2, 3, 4], [0]), ], ) def test_call_works_when_exclude_indices_are_specified( diff --git a/tests/unit/data/text/test_str_to_tensor_converter.py b/tests/unit/data/text/test_str_to_tensor_converter.py index efc4d2575..22cec94d0 100644 --- a/tests/unit/data/text/test_str_to_tensor_converter.py +++ b/tests/unit/data/text/test_str_to_tensor_converter.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any import pytest import torch @@ -22,7 +23,7 @@ def test_init_raises_error_when_data_type_is_not_supported(self) -> None: StrToTensorConverter(dtype=torch.half) @pytest.mark.parametrize("dtype", [None, torch.int16, torch.int32, torch.int64]) - def test_call_works(self, dtype: Optional[DataType]) -> None: + def test_call_works(self, dtype: DataType | None) -> None: s = "23 9 12 34 90 1 " converter = StrToTensorConverter(dtype=dtype) diff --git a/tests/unit/nn/test_position_encoder.py b/tests/unit/nn/test_position_encoder.py index 4db2f1422..f77b13430 100644 --- a/tests/unit/nn/test_position_encoder.py +++ b/tests/unit/nn/test_position_encoder.py @@ -14,6 +14,8 @@ IncrementalStateBag, LearnedPositionEncoder, RotaryEncoder, + Sinusoidal2dPositionEncoder, + Sinusoidal3dPositionEncoder, SinusoidalPositionEncoder, ) from fairseq2.utils.rng import temporary_manual_seed @@ -172,15 +174,44 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None: assert y.shape == (5, 2, 32) + def test_forward_works_with_padding_mask(self) -> None: + m = SinusoidalPositionEncoder( + encoding_dim=4, max_seq_len=10, _legacy_pad_idx=-1, device=device + ) + + x = torch.randn((3, 9, 4), device=device) + padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device) + padding_mask[:, 5:] = True + + y = m(x, padding_mask=padding_mask) + + assert y.shape == (3, 9, 4) + + def test_forward_works_with_multiple_batch_dims_and_padding(self) -> None: + m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10, device=device) + + x = torch.randn((4, 3, 9, 4), device=device) + padding_mask = torch.zeros((4, 3, 9), dtype=torch.bool, device=device) + padding_mask[..., 5:] = True + + y = m(x, padding_mask=padding_mask) + + assert y.shape == (4, 3, 9, 4) + + def test_extra_repr_works(self) -> None: + m = SinusoidalPositionEncoder(encoding_dim=4, max_seq_len=10, device=device) + + assert m.extra_repr() == "encoding_dim=4, max_seq_len=10" + class TestLearnedPositionEncoder: def test_init_works(self) -> None: - with temporary_manual_seed([device], seed=2): + with temporary_manual_seed(2, device): m = LearnedPositionEncoder(encoding_dim=32, max_seq_len=10, device=device) assert m.weight.dtype == torch.float32 - with temporary_manual_seed([device], seed=2): + with temporary_manual_seed(2, device): expected_weight = torch.randn(10, 32, device=device) assert_close(m.weight, expected_weight) @@ -249,6 +280,28 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None: assert y.shape == (5, 2, 32) + def test_forward_works_with_padding_mask(self) -> None: + m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=10, device=device) + + x = torch.randn((3, 9, 4), device=device) + padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device) + padding_mask[:, 5:] = True + + y = m(x, padding_mask=padding_mask) + + assert y.shape == (3, 9, 4) + + def test_forward_works_with_multiple_batch_dims_and_padding(self) -> None: + m = LearnedPositionEncoder(encoding_dim=4, max_seq_len=10, device=device) + + x = torch.randn((4, 3, 9, 4), device=device) + padding_mask = torch.zeros((4, 3, 9), dtype=torch.bool, device=device) + padding_mask[..., 5:] = True + + y = m(x, padding_mask=padding_mask) + + assert y.shape == (4, 3, 9, 4) + class TestRotaryEncoder: def test_init_raises_error_when_encoding_dim_is_odd(self) -> None: @@ -337,3 +390,134 @@ def test_forward_works_when_state_bag_is_not_none_in_training(self) -> None: y = m(x, padding_mask=None, state_bag=state_bag) assert y.shape == (5, 2, 32) + + def test_forward_works_with_padding_mask(self) -> None: + m = RotaryEncoder(encoding_dim=4, max_seq_len=10, device=device) + + x = torch.randn((3, 9, 4), device=device) + padding_mask = torch.zeros((3, 9), dtype=torch.bool, device=device) + padding_mask[:, 5:] = True + + y = m(x, padding_mask=padding_mask) + + assert y.shape == (3, 9, 4) + + def test_forward_works_with_custom_freqs_init(self) -> None: + def custom_freqs_init(encoder: RotaryEncoder) -> torch.Tensor: + return torch.ones(encoder.encoding_dim // 2, device=device) + + m = RotaryEncoder( + encoding_dim=4, + max_seq_len=10, + freqs_init_fn=custom_freqs_init, + device=device, + ) + + x = torch.randn((3, 9, 4), device=device) + y = m(x, padding_mask=None) + + assert y.shape == (3, 9, 4) + + +class TestSinusoidal2dPositionEncoder: + def test_init_raises_error_when_encoding_dim_is_odd(self) -> None: + with pytest.raises( + ValueError, match=r"^`encoding_dim` must be even, but is 13 instead\.$" + ): + Sinusoidal2dPositionEncoder( + encoding_dim=13, grid_dims=(10, 10), device=device + ) + + def test_forward_works(self) -> None: + m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device) + + # Test with same dimensions as grid + x = torch.randn((2, 8, 8, 4), device=device) + y = m(x) + + assert y.shape == (2, 8, 8, 4) + assert y.dtype == x.dtype + + # Test with different dimensions (should trigger interpolation) + x = torch.randn((2, 16, 16, 4), device=device) + y = m(x) + + assert y.shape == (2, 16, 16, 4) + assert y.dtype == x.dtype + + def test_forward_raises_error_on_wrong_dims(self) -> None: + m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device) + + # Test with wrong number of dimensions + x = torch.randn((2, 8, 4), device=device) + + with pytest.raises( + ValueError, + match=r"^`x` must be 4 dimensional, but is 3 dimensional instead\.$", + ): + m(x) + + def test_extra_repr_works(self) -> None: + m = Sinusoidal2dPositionEncoder(encoding_dim=4, grid_dims=(8, 8), device=device) + + assert m.extra_repr() == "encoding_dim=4, grid_dims=(8, 8)" + + +class TestSinusoidal3dPositionEncoder: + def test_init_raises_error_when_encoding_dim_is_odd(self) -> None: + with pytest.raises( + ValueError, match=r"^`encoding_dim` must be even, but is 13 instead\.$" + ): + Sinusoidal3dPositionEncoder( + encoding_dim=13, grid_dims=(8, 8, 8), device=device + ) + + def test_forward_works(self) -> None: + m = Sinusoidal3dPositionEncoder( + encoding_dim=6, grid_dims=(4, 4, 4), device=device + ) + + # Test with same dimensions as grid + x = torch.randn((2, 4, 4, 4, 6), device=device) + y = m(x) + + assert y.shape == (2, 4, 4, 4, 6) + assert y.dtype == x.dtype + + # Test with different dimensions (should trigger interpolation) + x = torch.randn((2, 8, 8, 8, 6), device=device) + y = m(x) + + assert y.shape == (2, 8, 8, 8, 6) + assert y.dtype == x.dtype + + def test_forward_works_with_uniform_power(self) -> None: + m = Sinusoidal3dPositionEncoder( + encoding_dim=6, grid_dims=(4, 4, 4), uniform_power=True, device=device + ) + + x = torch.randn((2, 4, 4, 4, 6), device=device) + y = m(x) + + assert y.shape == (2, 4, 4, 4, 6) + + def test_forward_raises_error_on_wrong_dims(self) -> None: + m = Sinusoidal3dPositionEncoder( + encoding_dim=6, grid_dims=(4, 4, 4), device=device + ) + + # Test with wrong number of dimensions + x = torch.randn((2, 4, 4, 6), device=device) + + with pytest.raises( + ValueError, + match=r"^`x` must be 5 dimensional, but is 4 dimensional instead\.$", + ): + m(x) + + def test_extra_repr_works(self) -> None: + m = Sinusoidal3dPositionEncoder( + encoding_dim=6, grid_dims=(4, 4, 4), device=device + ) + + assert m.extra_repr() == "encoding_dim=6, grid_dims=(4, 4, 4)" diff --git a/tests/unit/nn/transformer/test_attention.py b/tests/unit/nn/transformer/test_attention.py index fdc23d05e..484c0cc28 100644 --- a/tests/unit/nn/transformer/test_attention.py +++ b/tests/unit/nn/transformer/test_attention.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any import pytest import torch @@ -51,7 +51,7 @@ def test_torch_sdpa( @staticmethod def _get_sdpa_args( use_key_padding_mask: bool, use_attn_mask: bool - ) -> Dict[str, Any]: + ) -> dict[str, Any]: batch_size = 2 num_heads = 4 @@ -102,9 +102,7 @@ class TestStandardMultiheadAttention: (256, None), # same size, by default ], ) - def test_variable_sized_attention( - self, model_dim: int, kv_dim: Optional[int] - ) -> None: + def test_variable_sized_attention(self, model_dim: int, kv_dim: int | None) -> None: """ Testing that attention can work when the keys and values have a different size than queries. This may happen in encoder-decoder attention, if the encoder and decoder have different dimensions. diff --git a/tests/unit/nn/utils/test_gradient.py b/tests/unit/nn/utils/test_gradient.py index 89616bd8e..db8b6af22 100644 --- a/tests/unit/nn/utils/test_gradient.py +++ b/tests/unit/nn/utils/test_gradient.py @@ -32,6 +32,6 @@ def test_scale_gradient_raises_error_if_tensor_is_non_float() -> None: with pytest.raises( TypeError, - match=r"^`x` must be a float tensor, but is of type `torch\.int32` instead\.$", + match=r"^`x` must be a float tensor, but is a `torch\.int32` tensor instead\.$", ): scale_gradient(a, 1.0) diff --git a/tests/unit/optim/test_adamw.py b/tests/unit/optim/test_adamw.py index 18be90893..c49f34a01 100644 --- a/tests/unit/optim/test_adamw.py +++ b/tests/unit/optim/test_adamw.py @@ -6,8 +6,6 @@ from __future__ import annotations -from typing import Tuple - import pytest import torch from torch import Tensor @@ -57,11 +55,11 @@ def test_step_updates_fp16_params_correctly(self) -> None: if not torch.isnan(p2).any() and not torch.isinf(p2).any(): assert_close(p1, p2) - def run_step(self, dtype: DataType) -> Tuple[Module, Module]: - with temporary_manual_seed([device], seed=2): + def run_step(self, dtype: DataType) -> tuple[Module, Module]: + with temporary_manual_seed(2, device): net1 = AdamWTestNet(dtype) - with temporary_manual_seed([device], seed=2): + with temporary_manual_seed(2, device): net2 = AdamWTestNet(dtype) opt1 = AdamW( diff --git a/tests/unit/optim/test_lr_scheduler.py b/tests/unit/optim/test_lr_scheduler.py index 4b0dd6b13..3c81f1b22 100644 --- a/tests/unit/optim/test_lr_scheduler.py +++ b/tests/unit/optim/test_lr_scheduler.py @@ -7,7 +7,7 @@ from __future__ import annotations import math -from typing import Sequence, Union +from collections.abc import Sequence import pytest from torch import Tensor @@ -16,12 +16,15 @@ from torch.optim import SGD from fairseq2.optim.lr_scheduler import ( + COSINE_ANNEALING_LR, CosineAnnealingLR, + CosineAnnealingLRConfig, LRScheduler, MyleLR, NoamLR, PolynomialDecayLR, TriStageLR, + create_lr_scheduler, ) @@ -42,6 +45,7 @@ def setup_method(self) -> None: self.base_lr2 = 0.5 self.net = LRSchedulerTestNet() + self.opt = SGD( params=[ # type: ignore[arg-type] {"params": self.net.conv1.parameters()}, @@ -218,7 +222,7 @@ def test_cosine_with_no_cycle_scale(self) -> None: self.step(scheduler) @pytest.mark.parametrize("start_lr", [0.0, (0.0, 0.0), [0.02, 0.2]]) - def test_myle(self, start_lr: Union[float, Sequence[float]]) -> None: + def test_myle(self, start_lr: float | Sequence[float]) -> None: if isinstance(start_lr, float): start_lr1 = start_lr start_lr2 = start_lr @@ -544,3 +548,46 @@ def test_tristage(self) -> None: assert lr1 == pytest.approx(final_lr1) assert lr2 == pytest.approx(final_lr2) + + +class TestLRSchedulerFactory: + def setup_method(self) -> None: + self.net = LRSchedulerTestNet() + + self.opt = SGD( + params=[ # type: ignore[arg-type] + {"params": self.net.conv1.parameters()}, + {"params": self.net.conv2.parameters(), "lr": 0.5}, + ], + lr=0.05, + ) + + def test_cosine_annealing_lr_raises_error_when_both_final_lr_and_scale_are_specified( + self, + ) -> None: + config = CosineAnnealingLRConfig(final_lr=0.02, final_lr_scale=0.02) + + with pytest.raises( + ValueError, match=r"^`config.final_lr` and `config.final_lr_scale` must not be specified at the same time\.$" # fmt: skip + ): + create_lr_scheduler(COSINE_ANNEALING_LR, self.opt, config, max_num_steps=10) + + def test_cosine_annealing_lr_raises_error_when_both_final_lr_and_scale_are_none( + self, + ) -> None: + config = CosineAnnealingLRConfig(final_lr=None, final_lr_scale=None) + + with pytest.raises( + ValueError, match=r"^Either `config.final_lr` or `config.final_lr_scale` must be specified." # fmt: skip + ): + create_lr_scheduler(COSINE_ANNEALING_LR, self.opt, config, max_num_steps=10) + + def test_cosine_annealing_lr_raises_error_when_cycle_len_is_not_specified( + self, + ) -> None: + config = CosineAnnealingLRConfig(cycle_len=None) + + with pytest.raises( + ValueError, match=r"^`config.cycle_len` must be specified when `num_steps` is not specified\.$" # fmt: skip + ): + create_lr_scheduler(COSINE_ANNEALING_LR, self.opt, config) diff --git a/tests/unit/recipes/__init__.py b/tests/unit/recipes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/recipes/utils/__init__.py b/tests/unit/recipes/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/recipes/utils/test_sweep_tagger.py b/tests/unit/recipes/utils/test_sweep_tagger.py new file mode 100644 index 000000000..9630b8d9e --- /dev/null +++ b/tests/unit/recipes/utils/test_sweep_tagger.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Hashable + +import pytest + +from fairseq2.recipes.utils.sweep_tagger import ( + StandardSweepTagger, + SweepFormatPlaceholderError, +) + + +class TestStandardSweepTagger: + def test_call_works(self) -> None: + config = { + "foo1": "a", + "foo2": {"foo1": 0.2}, + "foo3": True, + "foo4": 1, + "foo5": 2.0, + "foo6": [1, 2, 3], + } + + tagger = self._create_tagger() + + tag = tagger.generate(2, "foo", config) + + assert tag == "ps_foo.ws_2.a618ea54" + + def test_call_works_when_key_order_is_different(self) -> None: + config = { + "foo1": "a", + "foo5": 2.0, + "foo2": {"foo1": 0.2}, + "foo4": 1, + "foo3": True, + "foo6": [1, 2, 3], + } + + tagger = self._create_tagger() + + tag = tagger.generate(2, "foo", config) + + assert tag == "ps_foo.ws_2.a618ea54" + + def test_call_works_when_keys_are_disallowed(self) -> None: + config = { + "foo1": "a", + "foo2": {"foo1": 0.2}, + "foo3": True, + "foo4": 1, + "foo5": 2.0, + "foo6": [1, 2, 3, {"foo7": "a"}], + "foo8": "b", # should be ignored. + "foo9": "c", # should be ignored. + } + + tagger = self._create_tagger() + + tag = tagger.generate(2, "foo", config) + + assert tag == "ps_foo.ws_2.a618ea54" + + def test_call_works_when_sweep_format_is_specified(self) -> None: + fmt = "ps_{preset}.{{foo9}}.foo5_{{{foo5}}}.foo21_{foo2.foo1}.foo61_{foo6[1]}.{hash}" + + config = { + "foo1": "a", + "foo5": 2.0, + "foo2": {"foo1": 0.2}, + "foo4": 1, + "foo3": True, + "foo6": [1, 2, 3], + } + + tagger = self._create_tagger() + + tag = tagger.generate(2, "foo", config, fmt=fmt) + + assert tag == "ps_foo.{foo9}.foo5_{2.0}.foo21_0.2.foo61_2.a618ea54" + + def test_call_raises_error_when_sweep_format_is_invalid(self) -> None: + fmt = "foo_{foo1" + + config = {"foo1": "a"} + + tagger = self._create_tagger() + + with pytest.raises( + ValueError, match=r"^`fmt` must have matching opening and closing braces.$" # fmt: skip + ): + tagger.generate(2, "foo", config, fmt=fmt) + + def test_call_raises_error_when_sweep_format_has_unknown_key(self) -> None: + fmt = "foo_{foo2}.foo_{foo3}.foo_{foo2}" + + config = {"foo1": "a"} + + tagger = self._create_tagger() + + with pytest.raises( + SweepFormatPlaceholderError, match=r"^`fmt` must contain only placeholders that correspond to the configuration keys, but contains the following unexpected placeholder\(s\): foo2, foo3$" # fmt: skip + ): + tagger.generate(2, "foo", config, fmt=fmt) + + @staticmethod + def _create_tagger() -> StandardSweepTagger: + allowed_keys: set[Hashable] = {f"foo{i}" for i in range(7)} + + return StandardSweepTagger(allowed_keys) diff --git a/tests/unit/test_config_registry.py b/tests/unit/test_config_registry.py index 28ed52c6e..8a3acfd4d 100644 --- a/tests/unit/test_config_registry.py +++ b/tests/unit/test_config_registry.py @@ -11,6 +11,7 @@ import pytest from fairseq2.config_registry import ConfigRegistry +from fairseq2.error import AlreadyExistsError @dataclass @@ -49,8 +50,7 @@ def test_register_raises_error_when_name_is_already_registered( registry.register("name", lambda: Foo("config")) with pytest.raises( - ValueError, - match=r"^`name` must be a unique configuration name, but 'name' has already a registered configuration factory\.$", + AlreadyExistsError, match=r"^The registry has already a configuration named 'name'\.$", # fmt: skip ): registry.register("name", lambda: Foo("config")) @@ -58,7 +58,6 @@ def test_get_raises_error_when_name_is_not_registered(self) -> None: registry = ConfigRegistry[Foo]() with pytest.raises( - ValueError, - match=r"^`name` must be a registered configuration name, but is 'foo' instead\.$", + LookupError, match=r"^'foo' is not a registered configuration name\.$", # fmt: skip ): registry.get("foo") diff --git a/tests/unit/test_gang.py b/tests/unit/test_gang.py new file mode 100644 index 000000000..9d645d3d7 --- /dev/null +++ b/tests/unit/test_gang.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import pytest +import torch +from torch import Tensor +from torch.distributed import ProcessGroup +from typing_extensions import override + +from fairseq2.gang import ( + AbstractGang, + GangError, + ReduceOperation, + _setup_2D_mesh_gangs, + setup_hybrid_fsdp_gangs, + setup_parallel_gangs, +) + + +class MockGang(AbstractGang): + """ + A mock gang that keeps track of the list of the process ranks. + """ + + _group_ranks: list[int] + + def __init__(self, group_ranks: list[int], *, rank: int = 0) -> None: + super().__init__(rank=rank, size=len(group_ranks), device=torch.device("cpu")) + self._group_ranks = list(group_ranks) + + @override + def close(self) -> None: + pass + + @override + def _do_make_gang(self, ranks: Sequence[int]) -> MockGang | None: + try: + idx = ranks.index(self._rank) + except ValueError: + return None + + return MockGang(list(ranks), rank=idx) + + @property + def group_ranks(self) -> list[int]: + return self._group_ranks + + @override + def as_process_group(self) -> ProcessGroup: + raise RuntimeError("This method should not be called for this mock gang.") + + @override + def barrier(self) -> None: + pass + + @override + def all_reduce(self, tensor: Tensor, op: ReduceOperation) -> None: + raise RuntimeError("This method should not be called for this mock gang.") + + @override + def all_gather(self, output_tensor: Tensor, input_tensor: Tensor) -> None: + raise RuntimeError("This method should not be called for this mock gang.") + + @override + def all_gather_to_list( + self, output_tensors: list[Tensor], input_tensor: Tensor + ) -> None: + raise RuntimeError("This method should not be called for this mock gang.") + + @override + def broadcast(self, tensor: Tensor, source_rank: int = 0) -> None: + raise RuntimeError("This method should not be called for this mock gang.") + + @override + def broadcast_objects(self, objects: list[Any], source_rank: int = 0) -> None: + raise RuntimeError("This method should not be called for this mock gang.") + + +class TestGang: + @pytest.mark.parametrize( + "rank,size,row_length,expected", + [ + (0, 2, 2, ([0, 1], [0])), + # mesh for 2 hosts, 2 GPUs each: + # Host 0: g0 | g1 + # Host 1: g2 | g3 + (0, 4, 2, ([0, 1], [0, 2])), + (1, 4, 2, ([0, 1], [1, 3])), + (2, 4, 2, ([2, 3], [0, 2])), + (0, 8, 4, ([0, 1, 2, 3], [0, 4])), + ], + ) + def test_setup_2D_mesh_works( + self, + rank: int, + size: int, + row_length: int, + expected: tuple[list[int], list[int]], + ) -> None: + root_gang = MockGang(list(range(size)), rank=rank) + + gangs = _setup_2D_mesh_gangs( + root_gang, + row_length=row_length, + create_single_rank_process_groups=True, + ) + + for i in range(2): + gang = gangs[i] + + # typecheck confirms that `create_single_rank_process_groups` works + assert isinstance(gang, MockGang) + + assert gang.group_ranks == expected[i] + + @pytest.mark.parametrize( + "rank,size,row_length", + [ + (0, 2, 0), + (0, 2, 3), + (0, 16, 7), + ], + ) + def test_setup_with_2D_mesh_raises_exception_on_bad_mesh( + self, rank: int, size: int, row_length: int + ) -> None: + root_gang = MockGang(list(range(size)), rank=rank) + + with pytest.raises((ValueError, GangError)): + setup_hybrid_fsdp_gangs(root_gang, row_length) + + with pytest.raises((ValueError, GangError)): + setup_parallel_gangs(root_gang, tp_size=row_length) diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index de9742ada..22b7b6ec3 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -72,6 +72,6 @@ def test_load_state_dict_raises_error_when_state_dict_is_corrupt(self) -> None: with pytest.raises( ValueError, - match=r"^`state_dict` must contain metrics \['test1', 'test2'\], but contains \['foo'\] instead\.$", + match=r"^`state_dict` must contain the states of the following metric\(s\): test1, test2$", ): bag.load_state_dict(state_dict) diff --git a/tests/unit/utils/test_dataclass.py b/tests/unit/utils/test_dataclass.py index f01302b1f..fb58cb65f 100644 --- a/tests/unit/utils/test_dataclass.py +++ b/tests/unit/utils/test_dataclass.py @@ -7,70 +7,46 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional -import pytest - -from fairseq2.utils.dataclass import FieldError, update_dataclass - - -@dataclass -class Foo2: - x: Optional[int] - y: str +from fairseq2.utils.dataclass import EMPTY, merge_dataclass @dataclass class Foo1: a: int - b: str - c: Foo2 - d: Optional[Foo2] - - -class TestUpdateDataclassFunction: - def test_call_works(self) -> None: - obj = Foo1(a=1, b="b", c=Foo2(x=2, y="foo3"), d=Foo2(x=3, y="foo3")) + b: Foo2 | Foo3 + c: str - overrides = {"b": "a", "c": {"x": None, "y": "foo4"}, "d": None} - unknown_fields = update_dataclass(obj, overrides) - - assert obj == Foo1(a=1, b="a", c=Foo2(x=None, y="foo4"), d=None) - - assert unknown_fields == [] - - def test_call_works_when_overrides_is_empty(self) -> None: - obj = Foo1(a=1, b="b", c=Foo2(x=2, y="foo3"), d=Foo2(x=3, y="foo3")) - - update_dataclass(obj, {}) +@dataclass +class Foo2: + x: int - assert obj == Foo1(a=1, b="b", c=Foo2(x=2, y="foo3"), d=Foo2(x=3, y="foo3")) - unknown_fields = update_dataclass(obj, {"c": {}}) +@dataclass +class Foo3: + y: int = 1 + z: int = 2 - assert obj == Foo1(a=1, b="b", c=Foo2(x=2, y="foo3"), d=Foo2(x=3, y="foo3")) - assert unknown_fields == [] +def test_merge_dataclass() -> None: + target = Foo1(a=3, b=Foo3(y=5), c="foo") + source = Foo1(a=2, b=Foo3(y=EMPTY, z=3), c=EMPTY) # type: ignore[arg-type] - def test_call_raises_error_when_there_are_invalid_overrides(self) -> None: - obj = Foo1(a=1, b="b", c=Foo2(x=2, y=3), d=Foo2(x=3, y="foo3")) # type: ignore[arg-type] + target = merge_dataclass(target, source) - overrides = {"c": 4} # type: ignore[dict-item] + assert target == Foo1(a=2, b=Foo3(y=5, z=3), c="foo") - with pytest.raises( - FieldError, - match=rf"^The field 'c' is expected to be of type `{Foo2}`, but is of type `{int}` instead\.$", - ): - update_dataclass(obj, overrides) + target = Foo1(a=3, b=Foo3(y=1), c="foo") + source = Foo1(a=EMPTY, b=Foo3(y=2, z=EMPTY), c="foo") # type: ignore[arg-type] - def test_call_works_when_there_are_unknown_overrides(self) -> None: - obj = Foo1(a=1, b="b", c=Foo2(x=2, y="foo3"), d=Foo2(x=3, y="foo3")) + target = merge_dataclass(target, source) - overrides = {"b": "a", "c": {"y": "foo4", "z": 2}, "e": 4} + assert target == Foo1(a=3, b=Foo3(y=2, z=2), c="foo") - unknown_fields = update_dataclass(obj, overrides) + target = Foo1(a=3, b=Foo2(x=1), c="foo") + source = Foo1(a=2, b=EMPTY, c="foo") # type: ignore[arg-type] - assert obj == Foo1(a=1, b="a", c=Foo2(x=2, y="foo4"), d=Foo2(x=3, y="foo3")) + target = merge_dataclass(target, source) - assert unknown_fields == ["c.z", "e"] + assert target == Foo1(a=2, b=Foo2(x=1), c="foo") diff --git a/tests/unit/utils/test_state.py b/tests/unit/utils/test_state.py index 6807b186b..7c37dc8b5 100644 --- a/tests/unit/utils/test_state.py +++ b/tests/unit/utils/test_state.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Any, Dict, Mapping +from collections.abc import Mapping +from typing import Any import pytest @@ -18,7 +19,7 @@ class TestStatefulObjectBag: def test_state_dict_works(self) -> None: class Foo: - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"foo4": "value4"} def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: @@ -56,7 +57,7 @@ def set_state(self, stateful: Any, state: Any) -> None: with pytest.raises( ValueError, - match="^`state_dict` must only contain the states of the attributes of this object, but it contains the following extra keys: foo3", + match="^`state_dict` must contain only the states of the attributes of this object, but it contains the following unexpected keys: foo3", ): bag.load_state_dict(state_dict) diff --git a/tests/unit/utils/test_structured.py b/tests/unit/utils/test_structured.py new file mode 100644 index 000000000..ddd2d94f6 --- /dev/null +++ b/tests/unit/utils/test_structured.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path + +import pytest +import torch + +from fairseq2.typing import DataType +from fairseq2.utils.dataclass import EMPTY +from fairseq2.utils.structured import ( + StructureError, + is_unstructured, + merge_unstructured, + structure, + unstructure, +) + +# mypy: disable-error-code="arg-type" + + +class FooEnum(Enum): + VALUE1 = 1 + VALUE2 = 2 + + +@dataclass +class Foo1: + f0: str = "foo" + f1: int = 0 + f2: dict[str, Path] = field(default_factory=dict) + f3: list[int] = field(default_factory=list) + f4: Foo2 | Foo3 = field(default_factory=lambda: Foo2()) + f5: tuple[float, int] = (1.0, 2) + f6: set[int] = field(default_factory=set) + f7: FooEnum | None = None + f8: DataType = torch.float32 + f9: Foo3 | None = None + + +@dataclass +class Foo2: + f2_1: int = 1 + + +@dataclass +class Foo3: + f3_1: int = 2 + f3_2: int = 3 + + +def test_structure_works() -> None: + data = { + "f0": "abc", + "f1": "2", + "f2": {"a": "path1", "b": Path("path2")}, + "f3": [0, "1", 2, "3"], + "f4": {"f3_1": "4"}, + "f5": ["3.0", "4"], + "f6": ["1", "2", "3"], + "f7": "VALUE2", + "f9": {"f3_2": "4"}, + } + + foo = structure(data, Foo1) + + expected_foo = Foo1( + f0="abc", + f1=2, + f2={"a": Path("path1"), "b": Path("path2")}, + f3=[0, 1, 2, 3], + f4=Foo3(f3_1=4, f3_2=3), + f5=(3.0, 4), + f6={1, 2, 3}, + f7=FooEnum.VALUE2, + f8=torch.float32, + f9=Foo3(f3_1=2, f3_2=4), + ) + + assert foo == expected_foo + + +def test_structure_works_when_set_empty_is_true() -> None: + data = { + "f0": "abc", + "f1": "2", + "f2": {"a": "path1", "b": Path("path2")}, + "f3": [0, "1", 2, "3"], + "f4": {"f3_1": "4"}, + "f5": ["3.0", "4"], + "f6": ["1", "2", "3"], + "f7": "VALUE2", + "f9": {"f3_2": "4"}, + } + + foo = structure(data, Foo1, set_empty=True) + + expected_foo = Foo1( + f0="abc", + f1=2, + f2={"a": Path("path1"), "b": Path("path2")}, + f3=[0, 1, 2, 3], + f4=Foo3(f3_1=4, f3_2=EMPTY), + f5=(3.0, 4), + f6={1, 2, 3}, + f7=FooEnum.VALUE2, + f8=EMPTY, + f9=Foo3(f3_1=EMPTY, f3_2=4), + ) + + assert foo == expected_foo + + +@pytest.mark.parametrize( + "data,kls", + [ + ("a", int), + ({"a": 1}, dict), + ("a", list), + ("a", FooEnum), + ({"f1_1": 2, "f1_2": 3}, Foo2), + ], +) +def test_structure_raises_error_when_conversion_fails(data: object, kls: type) -> None: + with pytest.raises( + StructureError, match=rf"^`obj` cannot be structured to `{kls}`\. See the nested exception for details\.$" # fmt: skip + ): + structure(data, kls) + + +def test_unstructure_works() -> None: + foo = Foo1( + f0="abc", + f1=2, + f2={"a": Path("path1"), "b": Path("path2")}, + f3=[0, 1, 2, 3], + f4=Foo3(f3_1=4), + f5=(3.0, 4), + f6={1, 2, 3}, + f7=FooEnum.VALUE2, + f8=torch.float16, + f9=Foo3(f3_1=1), + ) + + data = unstructure(foo) + + expected_data = { + "f0": "abc", + "f1": 2, + "f2": {"a": "path1", "b": "path2"}, + "f3": [0, 1, 2, 3], + "f4": {"f3_1": 4, "f3_2": 3}, + "f5": [3.0, 4], + "f6": [1, 2, 3], + "f7": "VALUE2", + "f8": "float16", + "f9": {"f3_1": 1, "f3_2": 3}, + } + + assert data == expected_data + + +def test_is_unstructured_works_when_object_is_unstructed() -> None: + obj = { + "foo1": True, + "foo2": 1, + "foo3": 1.0, + "foo4": "a", + "foo5": { + "foo6": "x", + }, + "foo7": [1, False, 3.0, "a"], + "foo8": None, + } + + assert is_unstructured(obj) + + +def test_is_unstructured_works_when_object_is_structed() -> None: + obj = object() + + assert not is_unstructured(obj) + + obj = { + "foo1": True, + "foo2": object(), + "foo3": "a", + } + + assert not is_unstructured(obj) + + +def test_merge_unstructured_works() -> None: + target = { + "foo1": "abc", + "foo2": { + "foo2_foo1": 4, + "foo2_foo2": { + "foo2_foo2_foo1": "x", + }, + "foo2_foo3": 4, + }, + "foo3": True, + "foo4": { + "foo4_foo1": "y", + "foo4_foo2": "z", + }, + "foo5": 1.0, + "foo6": [1, 2, 3], + } + + source = { + "_del_": ["foo3"], + "_add_": { + "foo4": { + "foo4_foo3": 2, + }, + "foo7": 2.0, + "foo8": 3, + }, + "_set_": { + "foo5": 2.0, + }, + "foo2": { + "_del_": ["foo2_foo1"], + "_add_": { + "foo3_foo4": "a", + }, + "foo2_foo2": { + "foo2_foo2_foo1": "b", + }, + "foo2_foo3": 5, + }, + "foo6": [4, 5, 6], + } + + output = merge_unstructured(target, source) + + expected_output = { + "foo1": "abc", + "foo2": { + "foo2_foo2": { + "foo2_foo2_foo1": "b", + }, + "foo2_foo3": 5, + "foo3_foo4": "a", + }, + "foo4": { + "foo4_foo3": 2, + }, + "foo5": 2.0, + "foo6": [4, 5, 6], + "foo7": 2.0, + "foo8": 3, + } + + assert output == expected_output + + +def test_merge_unstructured_raises_error_when_type_is_invalid() -> None: + target: object + source: object + + target = object() + source = None + + with pytest.raises( + StructureError, match=r"^`target` must be of a composition of types `bool`, `int`, `float`, `str`, `list`, and `dict`\.$" # fmt: skip + ): + merge_unstructured(target, source) + + target = None + source = object() + + with pytest.raises( + StructureError, match=r"^`source` must be of a composition of types `bool`, `int`, `float`, `str`, `list`, and `dict`\.$" # fmt: skip + ): + merge_unstructured(target, source) + + target = {} + source = {"_del_": "foo"} + + with pytest.raises( + StructureError, match=r"^'_del_' in `source` must be of type `list`, but is of type `str` instead\.$" # fmt: skip + ): + merge_unstructured(target, source) + + target = {"foo1": {"foo2": {}}} + source = {"foo1": {"foo2": {"_del_": "foo"}}} + + with pytest.raises( + StructureError, match=r"^'foo1\.foo2\._del_' in `source` must be of type `list`, but is of type `str` instead\.$" # fmt: skip + ): + merge_unstructured(target, source) + + target = {"foo1": {"foo2": {}}} + source = {"foo1": {"foo2": {"_del_": [0]}}} + + with pytest.raises( + StructureError, match=r"^Each element under 'foo1\.foo2\._del_' in `source` must be of type `str`, but the element at index 0 is of type `int` instead\.$" # fmt: skip + ): + merge_unstructured(target, source) + + target = {"foo1": {"foo2": {}}} + source = {"foo1": {"foo2": {"_add_": "foo"}}} + + with pytest.raises( + StructureError, match=r"^'foo1\.foo2\._add_' in `source` must be of type `dict`, but is of type `str` instead\.$" # fmt: skip + ): + merge_unstructured(target, source) + + target = {"foo1": {"foo2": {}}} + source = {"foo1": {"foo2": {"_add_": {0: "foo"}}}} + + with pytest.raises( + StructureError, match=r"^Each key under 'foo1\.foo2\._add_' in `source` must be of type `str`, but the key at index 0 is of type `int` instead\.$" # fmt: skip + ): + merge_unstructured(target, source) + + +def test_merge_unstructured_raises_error_when_path_does_not_exist() -> None: + target: object + source: object + + target = {"foo1": 0} + source = {"foo2": 1} + + with pytest.raises( + ValueError, match=r"^`target` must contain a 'foo2' key since it exists in `source`\." # fmt: skip + ): + merge_unstructured(target, source) + + target = {"foo1": {"foo2": 0}} + source = {"foo1": {"foo3": 1}} + + with pytest.raises( + ValueError, match=r"^`target` must contain a 'foo1\.foo3' key since it exists in `source`\." # fmt: skip + ): + merge_unstructured(target, source) diff --git a/tools/set-project-version.sh b/tools/set-project-version.sh index eeb0329fd..9d1d15901 100755 --- a/tools/set-project-version.sh +++ b/tools/set-project-version.sh @@ -10,7 +10,7 @@ set -eo pipefail function print_usage { - echo "Usage: set-version PEP440_VERSION" + echo "Usage: set-version [--native-only] PEP440_VERSION" } function exit_with_usage @@ -47,12 +47,20 @@ function extract_mmm_version echo "$1" | grep --only-matching --extended-regexp '^([0-9]+\.)*[0-9]+' - } -if [[ $# -ne 1 ]]; then - exit_with_error -fi +if [[ $# -eq 1 ]]; then + if [[ $1 == -h || $1 == --help ]]; then + exit_with_usage + fi + + native_only=false +elif [[ $# -eq 2 ]]; then + if [[ $1 == "--native-only" ]]; then + shift -if [[ $1 == -h || $1 == --help ]]; then - exit_with_usage + native_only=true + else + exit_with_error + fi fi pep_ver=$(extract_pep_version "$1") @@ -61,14 +69,16 @@ mmm_ver=$(extract_mmm_version "$1") base=$(cd "$(dirname "$0")/.." && pwd) # Update Python distribution. -replace_match\ - "$base/setup.py"\ - "s/^version = \".*\"$/version = \"$pep_ver\"/" - -# Update Python package. -replace_match\ - "$base/src/fairseq2/__init__.py"\ - "s/^__version__ = \".*\"$/__version__ = \"$pep_ver\"/" +if [[ $native_only != true ]]; then + replace_match\ + "$base/setup.py"\ + "s/^version = \".*\"$/version = \"$pep_ver\"/" + + # Update Python package. + replace_match\ + "$base/src/fairseq2/__init__.py"\ + "s/^__version__ = \".*\"$/__version__ = \"$pep_ver\"/" +fi # Update fairseq2n CMake project. replace_match\