From 317f18c9c2ec8b2e610528640c3aa6b59914f9ea Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Mon, 2 Dec 2024 17:28:14 +0800 Subject: [PATCH 01/12] fix check CUDA_DEVICE_MAX_CONNECTIONS --- internlm/utils/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 56ebcfbe..de885b72 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -250,8 +250,9 @@ def enable_pytorch_expandable_segments(): def check_cuda_env(): - if os.getenv("CUDA_DEVICE_MAX_CONNECTIONS") is None: - logger.warning("Env var CUDA_DEVICE_MAX_CONNECTIONS has not be set, please note this!") + max_connections = os.getenv("CUDA_DEVICE_MAX_CONNECTIONS") + assert max_connections is not None, "Env var CUDA_DEVICE_MAX_CONNECTIONS has not been set, please set it to 1!" + assert max_connections == '1', "Env var CUDA_DEVICE_MAX_CONNECTIONS is set to {}, but it should be set to 1!".format(max_connections) class DummyProfile: From 6b7df0bbe90803855fe2472aad50a97c31f02ac9 Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Mon, 2 Dec 2024 18:04:58 +0800 Subject: [PATCH 02/12] Revert "fix check CUDA_DEVICE_MAX_CONNECTIONS" This reverts commit 317f18c9c2ec8b2e610528640c3aa6b59914f9ea. --- internlm/utils/common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index de885b72..56ebcfbe 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -250,9 +250,8 @@ def enable_pytorch_expandable_segments(): def check_cuda_env(): - max_connections = os.getenv("CUDA_DEVICE_MAX_CONNECTIONS") - assert max_connections is not None, "Env var CUDA_DEVICE_MAX_CONNECTIONS has not been set, please set it to 1!" - assert max_connections == '1', "Env var CUDA_DEVICE_MAX_CONNECTIONS is set to {}, but it should be set to 1!".format(max_connections) + if os.getenv("CUDA_DEVICE_MAX_CONNECTIONS") is None: + logger.warning("Env var CUDA_DEVICE_MAX_CONNECTIONS has not be set, please note this!") class DummyProfile: From 902b63271851ba8eb974be118cc45c86a98311c4 Mon Sep 17 00:00:00 2001 From: kkscilife <126147887+kkscilife@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:06:11 +0800 Subject: [PATCH 03/12] test(ci): move ci to T (#388) Co-authored-by: kkscilife --- .github/workflows/e2e_test.yaml | 132 +++++++++++++++----------------- 1 file changed, 62 insertions(+), 70 deletions(-) diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml index ff992c1f..f36c0566 100644 --- a/.github/workflows/e2e_test.yaml +++ b/.github/workflows/e2e_test.yaml @@ -73,7 +73,7 @@ jobs: training_8GPU_4DP2TP: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 15 steps: @@ -81,21 +81,20 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_8GPU_4DP2TP_910B - if: ${{ matrix.runner == '910B' }} + - name: training_8GPU_4DP2TP_T + if: ${{ matrix.runner == 't_cluster' }} run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --nproc_per_node=8 --nnodes=1 -m pytest -p no:cacheprovider -v --color=yes -m "training_8GPU_4DP2TP" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU_4DP2TP" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_8GPU_4DP2TPSP: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 15 steps: @@ -103,21 +102,20 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_8GPU_4DP2TPSP_910B - if: ${{ matrix.runner == '910B' }} + - name: training_8GPU_4DP2TPSP_T + if: ${{ matrix.runner == 't_cluster' }} run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --nproc_per_node=8 --nnodes=1 -m pytest -p no:cacheprovider -v --color=yes -m "training_8GPU_4DP2TPSP" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU_4DP2TPSP" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_8GPU_4DP2PP: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 15 steps: @@ -125,16 +123,15 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_8GPU_4DP2PP_910B - if: ${{ matrix.runner == '910B' }} + - name: training_8GPU_4DP2PP_T + if: ${{ matrix.runner == 't_cluster' }} run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --nproc_per_node=8 --nnodes=1 -m pytest -p no:cacheprovider -v --color=yes -m "training_8GPU_4DP2PP" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU_4DP2PP" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_8GPU_4DP2PP_ZB: runs-on: [t_cluster] @@ -157,7 +154,7 @@ jobs: training_16GPU_4DP2TP2PP_MTP: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 15 steps: @@ -165,21 +162,20 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_16GPU_4DP2TP2PP_MTP_910B - if: ${{ matrix.runner == '910B' }} + - name: training_16GPU_4DP2TP2PP_MTP_T + if: ${{ matrix.runner == 't_cluster' }} run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=8 --nnodes=2 --node_rank=$RANK -m pytest -p no:cacheprovider -v --color=yes -m "training_16GPU_4DP2TP2PP_MTP" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" 2 "AllReduce" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_4DP2TP2PP_MTP" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_16GPU_4DP2TP2PP_MSP: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 15 steps: @@ -187,21 +183,20 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_16GPU_4DP2TP2PP_MSP_910B - if: ${{ matrix.runner == '910B' }} + - name: training_16GPU_4DP2TP2PP_MSP_T + if: ${{ matrix.runner == 't_cluster' }} run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=8 --nnodes=2 --node_rank=$RANK -m pytest -p no:cacheprovider -v --color=yes -m "training_16GPU_4DP2TP2PP_MSP" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" 2 "AllReduce" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_4DP2TP2PP_MSP" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_16GPU_4DP2TP2PP_FSP: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 15 steps: @@ -209,21 +204,20 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_16GPU_4DP2TP2PP_FSP_910B - if: ${{ matrix.runner == '910B' }} + - name: training_16GPU_4DP2TP2PP_FSP_T + if: ${{ matrix.runner == 't_cluster' }} run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=8 --nnodes=2 --node_rank=$RANK -m pytest -p no:cacheprovider -v --color=yes -m "training_16GPU_4DP2TP2PP_FSP" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" 2 "AllReduce" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_4DP2TP2PP_FSP" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_llama2: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 20 steps: @@ -231,20 +225,19 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_llama2_910B + - name: training_llama2_T run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --nproc_per_node=8 --nnodes=1 -m pytest -p no:cacheprovider -v --color=yes -m "training_llama2" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_llama2" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname training_internlm2: strategy: matrix: - runner: [910B] + runner: [t_cluster] runs-on: ${{ matrix.runner }} timeout-minutes: 20 steps: @@ -252,12 +245,11 @@ jobs: run: | echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - if [[ ${{ matrix.runner }} == 910B ]];then - sudo git clean -ffdx - fi - uses: actions/checkout@v3 - - name: training_internlm2_910B + - name: training_internlm2_T run: | - jobname=EB910-${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - start_command='torchrun --nproc_per_node=8 --nnodes=1 -m pytest -p no:cacheprovider -v --color=yes -m "training_internlm2" ./tests/test_training/test_loss.py' - bash ../910B_sco.sh $jobname "$start_command" + source activate ${evo_env_torch21_flash2} + jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_internlm2" ./tests/test_training/test_loss.py + exit_code=$? + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname From f6c66bd61ea27159389b4bc83d13e6bfb1a271a0 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:43:02 +0800 Subject: [PATCH 04/12] fix(mlp.py): swap mlp w1w2w3 init order to w1w3w2 and fix QA (#384) --- .github/workflows/e2e_test.yaml | 24 +-- configs/7B_isp_sft.py | 12 +- internlm/checkpoint/load_funcs.py | 2 + internlm/model/modeling_internlm2.py | 193 ++++++++++++++++++ internlm/model/modules/mlp.py | 18 +- internlm/model/ops/attention.py | 20 +- internlm/model/utils.py | 20 ++ tests/test_training/7B_check_acc.py | 59 +++++- tests/test_training/7B_check_init.py | 49 ++++- .../test_forward_output_no_fa.py | 11 +- tests/test_training/test_loss.py | 78 ++++--- tests/test_training/train_CI.py | 21 +- 12 files changed, 392 insertions(+), 115 deletions(-) diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml index f36c0566..28b0e4f1 100644 --- a/.github/workflows/e2e_test.yaml +++ b/.github/workflows/e2e_test.yaml @@ -1,5 +1,5 @@ name: e2e-tests -on: +on: pull_request: branches: - "develop" @@ -232,24 +232,4 @@ jobs: jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_llama2" ./tests/test_training/test_loss.py exit_code=$? - sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname - - training_internlm2: - strategy: - matrix: - runner: [t_cluster] - runs-on: ${{ matrix.runner }} - timeout-minutes: 20 - steps: - - name: mask env - run: | - echo "::add-mask::${{env.WORKSPACE_PREFIX}}" - echo "::add-mask::$path_prefix" - - uses: actions/checkout@v3 - - name: training_internlm2_T - run: | - source activate ${evo_env_torch21_flash2} - jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_internlm2" ./tests/test_training/test_loss.py - exit_code=$? - sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname + sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname \ No newline at end of file diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 2698a82f..de99f917 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -1,5 +1,5 @@ JOB_NAME = "7b_train" -# model_type = "INTERNLM2_PUBLIC" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False VOCAB_SIZE = 103168 @@ -31,7 +31,7 @@ # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined # load function such as "llama" load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering @@ -145,7 +145,7 @@ parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, - # no_bias=True, + no_bias=True, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" @@ -188,17 +188,17 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. sequence_2D (dict): 1. enable: bool, whether enable the 2D sequence parallel or not. - 2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses). + 2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses). head_size * context_size should be equal tensor size. 3. context_size: int, the parallel degree of context parallelism. head_size * context_size should be equal tensor size. 4. window_size: int, the sliding window size in context parallelism. 5. device_placement_strategy: dict, - head_first: bool, if `True`, ranks of the same head parallel group are + head_first: bool, if `True`, ranks of the same head parallel group are given high priority for colocation on the same node; if `False`, ranks of the same context parallel group are given high priority for colocation on the same node; - interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could + interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could interleaved the ranks in the same window to make full use of NIC as much as possible. """ parallel = dict( diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index d23cae63..dde4bc52 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -1,6 +1,7 @@ # Copyright (c) InternLM. All rights reserved. from internlm.model.modeling_internlm import InternLM1 +from internlm.model.modeling_internlm2 import InternLM2 from internlm.model.modeling_llama import Llama2 from internlm.utils.logger import get_logger @@ -9,4 +10,5 @@ LOAD_FUNC_DICT = { "llama": Llama2.load_llama_pretrained_weights, "internlm_test": InternLM1.load_internlm_with_dynamic_parallel_size, + "internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size, } diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index a4389b63..fedd27c4 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -1,6 +1,7 @@ # Copyright (c) InternLM. All rights reserved. import math import os +from functools import reduce from typing import Optional import torch @@ -11,6 +12,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.shard import partition_uniform from internlm.initialize.initialize_tensor import ( normal_, scaled_init_method_normal, @@ -26,6 +28,7 @@ from internlm.model.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, + get_parallel_size_from_file, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger @@ -576,6 +579,196 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: internlm_accelerator.empty_cache() + @staticmethod + def load_internlm2_with_dynamic_parallel_size(folder, model): + """Load InternLM2 with dynamic parallel size.""" + assert folder is not None, "Please specify the folder of the pretrained model" + assert gpc.config.model_type in ["INTERNLM2_PUBLIC"], "dynamic_parallel is only for INTERNLM2_PUBLIC" + + fns = get_fns(folder) + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612 + + tp = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + assert old_tp % tp == 0 or tp % old_tp == 0, ( + f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in " + f"checkpoint and {tp} in current config" + ) + + correspond_tps = [] + + if old_tp <= tp: + correspond_tps.append(tp_rank // (tp // old_tp)) + ratio = tp // old_tp + rank = tp_rank % ratio + else: + for i in range(old_tp // tp): + correspond_tps.append(tp_rank * (old_tp // tp) + i) + rank = 0 + ratio = 1 + + current_states = {} + + pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: F841 # pylint: disable=W0612 + + assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary" + + old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1) + + for idx, parts in enumerate(old_pp_partition): + start, end = parts[0] + if model.last_layer <= start or model.first_layer >= end: + continue + tmp_states = {} + + for correspond_tp in correspond_tps: + model_name = f"model_tp{correspond_tp}_pp{idx}.pt" + states = llm_load(os.path.join(folder, model_name), map_location="cpu") + states = {k.replace("model.", ""): v for k, v in states.items()} + for i in range(start, end): + if i >= model.last_layer: + break + if i < model.first_layer: + continue + + for name in list(states.keys()): + if f".{i-start}." in name: + to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.") + + if gpc.config.model_type == "INTERNLM2_PUBLIC": + if "norm" in name: + tmp_states[to_name] = [states.pop(name)] + elif any(x in name for x in ("wo", "w2")): + tmp_states[to_name] = tmp_states.get(to_name, []) + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank]) + elif any(x in name for x in ("w1", "w3")): + tmp_states[to_name] = tmp_states.get(to_name, []) + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) + elif any(x in name for x in ("wqkv",)): + tmp_states[to_name] = tmp_states.get(to_name, []) + if tp > gpc.config.model.num_kv_attention_heads: + assert old_tp <= gpc.config.model.num_kv_attention_heads, ( + f"`old_tp ({old_tp}) => tp ({tp})` is not supported. " + "At least one of `tp` and `old_tp` should be less than or " + "equal to `num_kv_attention_heads`" + ) + # Suitable for cases where the num_kv_attention_head is small, + # but you want to have a large TP Size + q_per_kv = ( + gpc.config.model.num_attention_heads + // gpc.config.model.num_kv_attention_heads + ) + head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads + index = torch.concat( + ( + torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio], + torch.tensor([q_per_kv, q_per_kv + 1]), + ) + ) + index = index + (q_per_kv + 2) * (tp_rank // ratio) + index = index % ( + (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp) + ) + index = index * head_dim + index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( + index.shape[0] + ) + tmp_states[to_name].append( + torch.index_select(states.pop(name), 0, index.to(torch.int32)) + ) + else: + tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank]) + else: + raise KeyError(f"Unknown key {name}.") + + else: + assert False, "unsupported model type" + + if "tok_embeddings.weight" in states and model.first_layer == 0: + tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", []) + tmp_states["tok_embeddings.weight"].append( + states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank] + ) + if "output.weight" in states and model.last_layer == gpc.config.model.num_layers: + tmp_states["norm.weight"] = [states["norm.weight"]] + tmp_states["output.weight"] = tmp_states.get("output.weight", []) + tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank]) + + states = {} + + for name in list(tmp_states.keys()): + data = tmp_states.pop(name) + if len(data) == 1: + current_states[name] = data[0] + else: + current_states[name] = torch.concat( + data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0 + ) + # Merge copied kv heads + if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads: + assert ( + tp <= gpc.config.model.num_kv_attention_heads + ), "new_tp should be less than or equal to num_kv_attention_heads" + head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads + q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads + copied_times = old_tp // gpc.config.model.num_kv_attention_heads + cur_q_per_kv = q_per_kv // copied_times + + # pylint: disable=all + def duplicate_kv_index(i): + if i % (cur_q_per_kv + 2) >= cur_q_per_kv: + return i + else: + return -100 + + def unique_kv_index(i): + if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv: + return i + else: + return -100 + + # pylint: enable=all + + # Verify + duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] + duplicate_index = [i for i in duplicate_index if i != -100] + duplicate_index = _duplicate_index = torch.tensor(duplicate_index) + for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): + duplicate_index = torch.concat( + (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0 + ) + duplicate_kv = [] + for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1): + index = index.reshape(-1) * head_dim + index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0]) + duplicate_kv.append(torch.index_select(current_states[name], 0, index)) + assert reduce( + lambda x, y: x and y, + [torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]], + ), "Copied kv heads are not equal after training!" + + # Merge + unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)] + unique_index = [i for i in unique_index if i != -100] + unique_index = _unique_index = torch.tensor(unique_index) + for i in range(gpc.config.model.num_kv_attention_heads // tp - 1): + unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0) + unique_index = unique_index * head_dim + unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat( + unique_index.shape[0] + ) + current_states[name] = torch.index_select(current_states[name], 0, unique_index) + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True): diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 6e74d6b6..e51e5897 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -99,12 +99,12 @@ def __init__( self.w1 = new_linear( "w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert ) - self.w2 = new_linear( - "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert - ) self.w3 = new_linear( "w3", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert ) + self.w2 = new_linear( + "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert + ) def forward(self, x): if not self.mlp_layer_fusion: @@ -177,10 +177,10 @@ def __init__( backend=backend, is_expert=is_expert, ) - self.w2 = new_linear( - "grouped_w2", + self.w3 = new_linear( + "grouped_w3", + in_features, hidden_features, - out_features, bias, device=device, dtype=dtype, @@ -188,10 +188,10 @@ def __init__( backend=backend, is_expert=is_expert, ) - self.w3 = new_linear( - "grouped_w3", - in_features, + self.w2 = new_linear( + "grouped_w2", hidden_features, + out_features, bias, device=device, dtype=dtype, diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index d0a668c8..c2639569 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -928,7 +928,7 @@ def _qkv_without_cu_seqlens(self, qkv, softmax_scale=None, causal=None, key_padd # TODO: more unified interface dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -944,7 +944,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -960,7 +960,7 @@ def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, ke attn_type, op = _select_attn_op(AttnOpType.FixedLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if (attn_type is AttnType.Torch and key_padding_mask is not None) else () + extra_args = (key_padding_mask,) if (attn_type is AttnType.Torch and key_padding_mask is not None) else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -984,7 +984,7 @@ def _qkv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenQKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(qkv, cu_seqlens, max_seqlen, dropout, softmax_scale, causal, *extra_args) @@ -1007,7 +1007,7 @@ def _q_kv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1033,7 +1033,7 @@ def _q_k_v_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1088,7 +1088,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(q, kv, dropout, softmax_scale, causal, *extra_args) @@ -1100,7 +1100,7 @@ def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, ke attn_type, op = _select_attn_op(AttnOpType.FixedLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(q, k, v, dropout, softmax_scale, causal, *extra_args) @@ -1123,7 +1123,7 @@ def _q_kv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1149,7 +1149,7 @@ def _q_k_v_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenQKVSplited) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args diff --git a/internlm/model/utils.py b/internlm/model/utils.py index e3ebf44d..7c974abe 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -6,9 +6,12 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.mha import MHA +from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load from internlm.utils.utils import TensorParallelMode +logger = get_logger(__file__) + def internlm1_mha_pre_load_convert( model: MHA, state_dict: Dict, prefix: str, *args, **kwargs # pylint: disable=W0613 @@ -138,3 +141,20 @@ def merge_pp_src_states(states): layer_shift += _layer_shift + 1 merged_states.append(shifted_state) return merged_states + + +def get_parallel_size_from_file(fns, suffix=None): + model_fns, old_tp, old_pp = [], -1, -1 + for fn in fns: + # filter with `_t` is for avoiding conflict with model_config.py + + if fn.startswith("model_t"): + if (suffix and fn.endswith(suffix)) or (suffix is None and not fn.endswith("md5")): + model_fns.append(fn) + _, tp, pp = os.path.splitext(fn)[0].split("_") + old_tp = max(old_tp, int(tp[2:]) + 1) + old_pp = max(old_pp, int(pp[2:]) + 1) + + assert old_tp > 0 and old_pp > 0, f"ckpt with tp:{old_tp} and pp:{old_pp} is illegal" + model_fns.sort() + return model_fns, old_tp, old_pp diff --git a/tests/test_training/7B_check_acc.py b/tests/test_training/7B_check_acc.py index 3b727d7c..cb3902bc 100644 --- a/tests/test_training/7B_check_acc.py +++ b/tests/test_training/7B_check_acc.py @@ -1,16 +1,20 @@ import os -JOB_NAME = "7b_train" +JOB_NAME = "7b_internlm2_train" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False +VOCAB_SIZE = 92544 SEQ_LEN = 2048 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 -MLP_RATIO = 8 / 3 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 3.5 NUM_LAYER = 32 -VOCAB_SIZE = 103168 -MODEL_ONLY_FOLDER = os.path.join(os.environ["share_path"], "quailty_assurance/7B_model_weights_ckpt/init") +MODEL_ONLY_FOLDER = os.path.join( + os.environ["share_path"], "quailty_assurance/7B_internlm2_init_dp=2_tp=2_pp=2_ckpt/init" +) # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' # SAVE_CKPT_FOLDER = "local:llm_ckpts_0925_9" @@ -121,7 +125,8 @@ ) model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + checkpoint=False, + num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, @@ -129,13 +134,22 @@ parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, + no_bias=True, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ zero1 parallel: @@ -150,9 +164,9 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node. """ parallel = dict( - zero1=dict(size=8), - tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True), + zero1=dict(size=-1), + tensor=dict(size=2, mode="mtp"), + pipeline=dict(size=2, interleaved_overlap=True), weight=dict(size=1, overlap=True), ) @@ -165,5 +179,30 @@ enable_feishu_alert=DO_ALERT, feishu_alert_address=None, # feishu webhook to send alert message light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, ), ) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) + +enable_tb = False diff --git a/tests/test_training/7B_check_init.py b/tests/test_training/7B_check_init.py index 6f72c7d7..03107d02 100644 --- a/tests/test_training/7B_check_init.py +++ b/tests/test_training/7B_check_init.py @@ -1,12 +1,14 @@ -JOB_NAME = "7b_train" +JOB_NAME = "7b_internlm2_train" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False +VOCAB_SIZE = 92544 SEQ_LEN = 2048 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 -MLP_RATIO = 8 / 3 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 3.5 NUM_LAYER = 32 -VOCAB_SIZE = 103168 CHECK_INIT = 1 @@ -128,7 +130,8 @@ ) model = dict( - checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + checkpoint=False, + num_chunks=1, num_attention_heads=NUM_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, @@ -136,13 +139,22 @@ parallel_output=True, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYER, + no_bias=True, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", norm_type="rmsnorm", layer_norm_epsilon=1e-5, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, - num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) parallel = dict( @@ -161,5 +173,30 @@ enable_feishu_alert=DO_ALERT, feishu_alert_address=None, # feishu webhook to send alert message light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, ), ) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) + +enable_tb = False diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index 69c9eb90..48b97bfa 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -56,7 +56,7 @@ checkpoint=True, num_attention_heads=32, embed_split_hidden=True, - vocab_size=103168, + vocab_size=92544, embed_grad_scale=1, parallel_output=False, hidden_size=4096, @@ -68,8 +68,9 @@ layer_norm_epsilon=1e-5, use_flash_attn=False, num_chunks=1, + no_bias=True, ), - model_type="INTERNLM", + model_type="INTERNLM2_PUBLIC", alert_address=None, monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)), grad_scaler=dict( @@ -178,7 +179,7 @@ def train_check_output(args): optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) - train_dl, dataset_types = build_train_loader_with_data_type() + _, dataset_types = build_train_loader_with_data_type() metric = AccPerplex( device=get_current_device(), @@ -226,9 +227,9 @@ def train_check_output(args): if gpc.is_rank_for_log(): standard_output_with_fa = torch.load( - os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa.pt") + os.path.join(share_path, "quailty_assurance/7B_no_flash_attention/output_with_fa_internlm2.pt") ) - tensor1 = standard_output_with_fa + tensor1 = standard_output_with_fa[0][0] tensor2 = output[0][0][0] if torch.equal(tensor1, tensor2): diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 4094c582..d1db7496 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -25,25 +25,26 @@ from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.megatron_timers import megatron_timer as timer -CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_sft.py") -INTERNLM1_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/model_ckpt") +CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_internlm2.py") +INTERNLM2_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss_pri/model_ckpt") TOTAL_STEPS = 10 LOSS_SPIKE_LIMIT = 1.5 LOSS_DEVIATION_LIMIT = 0.02 # dp_size = 4 BASELINE_LOSS_LIST = [ - 11.63298511505127, - 7.82645320892334, - 6.727725505828857, - 6.182029724121094, - 5.395882606506348, - 5.394383430480957, - 5.053952217102051, - 4.742049694061279, - 4.629276752471924, - 4.616517543792725, + 12.362918853759766, + 12.404379844665527, + 12.348219871520996, + 12.194982528686523, + 11.80469036102295, + 11.573806762695312, + 10.045475006103516, + 9.660882949829102, + 9.172087669372559, + 4.799427032470703, ] + cur_loss_list = [] internlm_accelerator = get_accelerator() @@ -59,7 +60,7 @@ def train( enable_sp: bool = False, save_ckpt: bool = False, load_ckpt: bool = False, - model_type: str = "INTERNLM", + model_type: str = "INTERNLM2_PUBLIC", optimizer_ver: str = "v1", pp_mode: str = "1F1B", ): @@ -67,24 +68,31 @@ def train( config = Config.from_file(CONFIG_FILE_PATH) # init setting - config.data.total_steps = TOTAL_STEPS + config.data.total_steps = 50000 config.data.fixed_random_dataset_seqlen = False - config.lr_scheduler.total_steps = TOTAL_STEPS + config.data.micro_num = 4 + config.data.micro_bsz = 2 + config.lr_scheduler.total_steps = config.data.total_steps config.model_type = model_type config.ckpt.load_ckpt_folder = None config.ckpt.load_ckpt_info = None config.ckpt.auto_resume = False - total_steps = config.data.total_steps + total_steps = TOTAL_STEPS skip_batches = config.data.skip_batches label_smoothing = config.loss.label_smoothing + config.parallel.zero1 = dict(size=-1) + config.parallel.tensor = dict(size=1, mode="mtp") + config.parallel.pipeline = dict(size=1, interleaved_overlap=True, mode="1f1b") + config.parallel.weight = dict(size=1, overlap=True) if optimizer_ver == "v2": config.hybrid_zero_optimizer.use_split_tensor_optim = True config.all_gather_size = 512 * 1024 * 1024 + config.model.checkpoint = True # update ckpt config - if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False: - config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test") + if model_type == "INTERNLM2_PUBLIC" and tp_mode != "isp" and interleaved is False: + config.ckpt.load_ckpt_info = dict(path=INTERNLM2_CKPT_PATH, content=("model",), ckpt_type="internlm2_test") if save_ckpt: config.ckpt.enable_save_ckpt = True @@ -213,7 +221,7 @@ def train( train_iter = iter(train_dl) - if model_type == "INTERNLM": + if model_type == "INTERNLM2_PUBLIC": data_path = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/data_batch_4DP") data_batch = torch.load(f"{data_path}/{gpc.get_local_rank(ParallelMode.DATA)}_data_batch.pt") @@ -222,7 +230,7 @@ def train( empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) timer("one-batch").start() - if model_type == "INTERNLM": + if model_type == "INTERNLM2_PUBLIC": if batch_count >= 10: batch = data_batch[batch_count - 10] else: @@ -296,7 +304,6 @@ def check_loss_spike(): def check_loss_accuracy(): if gpc.is_rank_for_log(): - print(f"cur_loss_list:{cur_loss_list}", flush=True) for cur, target in zip(cur_loss_list, BASELINE_LOSS_LIST): assert ( abs(cur - target) < LOSS_DEVIATION_LIMIT @@ -464,16 +471,16 @@ def test_training_with_isp(): global CONFIG_FILE_PATH, BASELINE_LOSS_LIST CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" BASELINE_LOSS_LIST = [ - 11.595988273620605, - 7.988386154174805, - 6.821506500244141, - 6.2768449783325195, - 5.478013515472412, - 5.4622697830200195, - 5.162247180938721, - 4.854615211486816, - 4.744818210601807, - 4.75523567199707, + 12.225811004638672, + 12.103824615478516, + 12.223844528198242, + 11.87704849243164, + 11.651590347290039, + 11.629219055175781, + 10.242591857910156, + 9.768388748168945, + 9.330610275268555, + 5.505439758300781, ] # model training @@ -516,12 +523,3 @@ def test_training_llama2(): CONFIG_FILE_PATH = "./configs/7B_llama2.py" train(dp_size=8, model_type="LLAMA2") - - -@pytest.mark.training_internlm2 -def test_training_internlm2(): - # update config file - global CONFIG_FILE_PATH - CONFIG_FILE_PATH = "./configs/7B_internlm2.py" - - train(dp_size=8, model_type="INTERNLM2_PUBLIC") diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index b33cf4c3..7926bae5 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -20,7 +20,7 @@ from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState, Trainer # noqa: E402 +from internlm.core.trainer import Trainer, TrainState # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -60,6 +60,7 @@ def check_model_weights(model, ckpt_path, total_equal=False): + model = model.model model1_dict = torch.load(ckpt_path, map_location="cuda") model2_dict = model.state_dict() @@ -214,13 +215,14 @@ def main(args): # check model init weights if hasattr(gpc.config, "CHECK_INIT") and gpc.config.CHECK_INIT == 1: ckpt_name = ( - f"model_dp{gpc.get_local_rank(ParallelMode.DATA)}" + f"model" f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}" f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt" ) - ckpt_path = os.path.join(os.environ["share_path"], "quailty_assurance/7B_init_dp=2_tp=2_pp=2_ckpt", ckpt_name) + ckpt_path = os.path.join( + os.environ["share_path"], "quailty_assurance/7B_internlm2_init_dp=2_tp=2_pp=2_ckpt/init", ckpt_name + ) check_model_weights(model, ckpt_path, total_equal=True) - with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: # start iterating the train data and begin training for batch_count in range(train_state.batch_count, total_steps): @@ -327,12 +329,17 @@ def main(args): ) # check model weights - if gpc.is_rank_for_log() and batch_count > 0 and batch_count % 100 == 0: + if batch_count > 0 and batch_count % 100 == 0: + ckpt_name = ( + f"model" + f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}" + f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}.pt" + ) ckpt_path = os.path.join( os.environ["share_path"], - "quailty_assurance/7B_model_weights_ckpt", + "quailty_assurance/7B_internlm2_init_dp=2_tp=2_pp=2_ckpt", str(batch_count), - "model_tp0_pp0.pt", + ckpt_name, ) check_model_weights(model, ckpt_path) From 71c32c82bc1e67f6937a3852c6d945400fbf3ed6 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Tue, 10 Dec 2024 14:24:58 +0800 Subject: [PATCH 05/12] fix(gmm): change communicator.grad_hook to async (#371) --- internlm/model/modules/linear.py | 38 ++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 856e6ba0..6f190268 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -344,6 +344,9 @@ def forward( ctx.compute_weight_gradient = weight.requires_grad ctx.backend = backend + saved_x = None if ctx.compute_weight_gradient is False else x + ctx.save_for_backward(saved_x, weight, batch_sizes) + if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.contiguous() @@ -358,8 +361,7 @@ def forward( output = torch.matmul(x, weight) - saved_x = None if ctx.compute_weight_gradient is False else x - ctx.save_for_backward(saved_x, weight, batch_sizes) + assert len(output.shape) == len(x.shape) return output @@ -372,6 +374,14 @@ def backward(ctx, grad_output): x, weight, batch_sizes = ctx.saved_tensors grad_input, grad_weight = None, None + if grad_output.numel() == 0: + if ctx.needs_input_grad[1]: + grad_weight = torch.zeros_like(weight) + if ctx.needs_input_grad[0]: + grad_input = torch.zeros_like(x) + + return grad_input, grad_weight, None, None, None, None, None + if ctx.needs_input_grad[1]: assert ctx.compute_weight_gradient if backend == "gmm": @@ -450,6 +460,8 @@ def forward( saved_x = None if ctx.compute_weight_gradient is False else x ctx.save_for_backward(saved_x, weight, batch_sizes) + assert len(output.shape) == len(x.shape) + return output @staticmethod @@ -461,20 +473,28 @@ def backward(ctx, grad_output): backend = ctx.backend full_weight_shape = ctx.full_weight_shape - grad_output = grad_output.contiguous() - - total_weight = communicator.weight_hook(weight, module=module) - total_weight = total_weight.reshape(full_weight_shape) - grad_input, grad_weight = None, None if grad_output.numel() == 0: + if ctx.needs_input_grad[1]: + total_weight_shape = torch.Size( + (full_weight_shape.numel() // full_weight_shape[-1], full_weight_shape[-1]) + ) + grad_weight = torch.zeros(total_weight_shape, dtype=weight.dtype, device=weight.device) + grad_weight, grad_weight_sync = communicator.grad_hook( + grad_weight, async_op=True, module=module, is_bias=False + ) if ctx.needs_input_grad[0]: grad_input = torch.zeros_like(x) if ctx.needs_input_grad[1]: - grad_weight = torch.zeros_like(total_weight).reshape(-1, full_weight_shape[-1]) - grad_weight, _ = communicator.grad_hook(grad_weight, async_op=False, module=module, is_bias=False) + grad_weight_sync.wait() return grad_input, grad_weight, None, None, None, None, None + grad_output = grad_output.contiguous() + + total_weight = communicator.weight_hook(weight, module=module) + total_weight = total_weight.reshape(full_weight_shape) + grad_input, grad_weight = None, None + if ctx.needs_input_grad[1]: assert ctx.compute_weight_gradient if backend == "gmm": From cd53c323957a08845e969d012c02cc27db5bf645 Mon Sep 17 00:00:00 2001 From: cx <759046501@qq.com> Date: Tue, 10 Dec 2024 14:25:26 +0800 Subject: [PATCH 06/12] Feat/refactor process group (#358) --- configs/57B_qwen2_MoE.py | 6 +- configs/8x22B_mixtral.py | 6 +- configs/8x7B_mixtral.py | 6 +- internlm/core/context/parallel_context.py | 110 +++--- .../core/context/process_group_initializer.py | 356 +++++++++++++++++- 5 files changed, 415 insertions(+), 69 deletions(-) diff --git a/configs/57B_qwen2_MoE.py b/configs/57B_qwen2_MoE.py index 0fd67603..abfb0a5b 100644 --- a/configs/57B_qwen2_MoE.py +++ b/configs/57B_qwen2_MoE.py @@ -190,7 +190,6 @@ weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. - 3. memory_pool: bool, enable/disable memory pool, defaults to False. expert parallel (dict): 1. size: int * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size @@ -201,15 +200,14 @@ expert weight parallel (dict): 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. - 3. memory_pool: bool, enable/disable memory pool, defaults to False. """ parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True, memory_pool=True), + weight=dict(size=1, overlap=True), expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True, memory_pool=True), + expert_weight=dict(size=1, overlap=True), ) cudnn_deterministic = False diff --git a/configs/8x22B_mixtral.py b/configs/8x22B_mixtral.py index 56206bd4..debd423b 100644 --- a/configs/8x22B_mixtral.py +++ b/configs/8x22B_mixtral.py @@ -191,7 +191,6 @@ weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. - 3. memory_pool: bool, enable/disable memory pool, defaults to False. expert parallel (dict): 1. size: int * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size @@ -202,15 +201,14 @@ expert weight parallel (dict): 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. - 3. memory_pool: bool, enable/disable memory pool, defaults to False. """ parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True, memory_pool=True), + weight=dict(size=1, overlap=True), expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True, memory_pool=True), + expert_weight=dict(size=1, overlap=True), ) cudnn_deterministic = False diff --git a/configs/8x7B_mixtral.py b/configs/8x7B_mixtral.py index f589c967..322342ea 100644 --- a/configs/8x7B_mixtral.py +++ b/configs/8x7B_mixtral.py @@ -191,7 +191,6 @@ weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. - 3. memory_pool: bool, enable/disable memory pool, defaults to False. expert parallel (dict): 1. size: int * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size @@ -202,15 +201,14 @@ expert weight parallel (dict): 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. - 3. memory_pool: bool, enable/disable memory pool, defaults to False. """ parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True, memory_pool=True), + weight=dict(size=1, overlap=True), expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True, memory_pool=True), + expert_weight=dict(size=1, overlap=True), ) cudnn_deterministic = False diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 989b1c00..8d74e3c6 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -21,8 +21,14 @@ from internlm.utils.timeout import LLM_NCCL_TIMEOUT from internlm.utils.utils import TensorParallelMode -from . import process_group_initializer as pgroup_initializer -from .process_group_initializer import ParallelMode +from .process_group_initializer import ( + GroupConfig, + ParallelMode, + create_parallel_process_groups, + create_single_process_group, + generate_2d_attn_process_group, + generate_parallel_group_configs, +) from .random import add_seed, get_seeds, set_mode # for layernorm @@ -633,60 +639,60 @@ def init_parallel_groups(self): self.check_sanity() - initializer_args = [ - rank, - world_size, - self.weight_parallel_size, - self.weight_data_parallel_size, - self.sequence_parallel_size, - self.data_parallel_size, - self.pipeline_parallel_size, - self.tensor_parallel_size, - self.zero1_parallel_size, - self.nettest_parallel_size, - self.expert_parallel_size, - self.expert_tensor_parallel_size, - self.expert_weight_parallel_size, - self.expert_data_parallel_size, - parallel_config.sequence_2D, - ] - - # run initialization of different process groups - initializers = [] - if "gqa" in parallel_config and parallel_config["gqa"] is True: - initializers.append(pgroup_initializer.Initializer_GQA(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Data(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args)) + parallel_sizes = { + ParallelMode.TENSOR: self.tensor_parallel_size, + ParallelMode.SEQUENCE: self.sequence_parallel_size, + ParallelMode.PIPELINE: self.pipeline_parallel_size, + ParallelMode.DATA: self.data_parallel_size, + ParallelMode.ZERO1: self.zero1_parallel_size, + ParallelMode.WEIGHT: self.weight_parallel_size, + ParallelMode.WEIGHT_DATA: self.weight_data_parallel_size, + ParallelMode.NETTEST: self.nettest_parallel_size, + ParallelMode.EXPERT: self.expert_parallel_size, + ParallelMode.EXPERT_WEIGHT: self.expert_weight_parallel_size, + ParallelMode.EXPERT_TENSOR: self.expert_tensor_parallel_size, + ParallelMode.EXPERT_DATA: self.expert_data_parallel_size, + } + + # process groups for parallelism. + enable_moe = self.config.model.get("num_experts", 1) > 1 + tp_mode = "mtp" if isinstance(parallel_config.tensor, int) else parallel_config.tensor.get("mode", "mtp") + is_fsdp = False if isinstance(parallel_config.zero1, int) else parallel_config.zero1.get("fsdp", False) + parallel_strategy = "fsdp" if is_fsdp else tp_mode + group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe) + group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False) + + # process group for extra gqa tensor parallel. if ( - isinstance(parallel_config["tensor"], dict) - and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name + "num_kv_attention_heads" in self.config.model + and self.config.model.num_kv_attention_heads < self.tensor_parallel_size ): - initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args)) - else: - initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) - if isinstance(parallel_config["zero1"], dict) and parallel_config["zero1"].get("fsdp", False): - initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args)) - initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args)) - if self.pipeline_parallel_size > 1: - initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args)) - if self.config.model.get("num_experts", 1) > 1: - if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp": - initializers.append(pgroup_initializer.Initializer_Expert_Weight_Data(*initializer_args)) - else: - initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args)) + group_results.append( + create_single_process_group( + world_size, + rank, + GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads), + ) + ) + + # process group for network test. + group_results.append( + create_single_process_group( + world_size, + rank, + GroupConfig(ParallelMode.NETTEST, self.nettest_parallel_size, allow_partial_group=True), + ) + ) + + # process group for isp 2D attn. if parallel_config.sequence_2D.get("enable", False) is True: - initializers.append(pgroup_initializer.Initializer_2D_SEQUENCE_PARALLEL(*initializer_args)) + group_results.extend( + generate_2d_attn_process_group(world_size, rank, parallel_config.sequence_2D, parallel_sizes) + ) - for initializer in initializers: - parallel_setting = initializer.init_dist_group() - if isinstance(parallel_setting, list): - for args in parallel_setting: - self._register_dist(*args) - else: - self._register_dist(*parallel_setting) + # register process groups + for result in group_results: + self._register_dist(*result) def is_initialized(self, parallel_mode: ParallelMode): """Returns a boolean value indicating whether `parallel_mode` is initialized diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index fbc3e07a..5313ad92 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -6,11 +6,16 @@ import math from abc import ABC, abstractmethod from enum import Enum +from functools import reduce +from typing import Any, Dict, List, Optional, Tuple, Union import torch.distributed as dist +from internlm.utils.logger import get_logger from internlm.utils.timeout import LLM_NCCL_TIMEOUT +logger = get_logger(__file__) + # parallel modes class ParallelMode(Enum): @@ -81,6 +86,349 @@ class ParallelMode(Enum): DKV_INTRA_WINDOW = "dkv_intra_window" +class GroupConfig: + """config for initialze a process group""" + + def __init__( + self, + mode: ParallelMode, + size: int, + anonymous: bool = False, + allow_partial_group: bool = False, + subgroups: Optional[List["GroupConfig"]] = None, + ) -> None: + self.mode = mode + self.size = size + self.anonymous = anonymous + self.allow_partial_group = allow_partial_group + self.subgroups = subgroups if subgroups is not None else [] + + self._early_subgroup_checking() + + def _early_subgroup_checking(self) -> None: + if len(self.subgroups) == 0: + return + + group_target_size = reduce(lambda x, y: x * y, [_g.size for _g in self.subgroups]) + assert group_target_size <= self.size, "subgroup size should less than father group" + + +def init_cpu_group(group, ranks, use_cpu: bool = False): + if use_cpu: + cpu_group = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) if dist.get_backend() != "gloo" else group + ) + else: + cpu_group = None + + return cpu_group + + +def get_group_ranks( + global_ranks_or_sizes: Union[int, List[int]], + cur_group_size: int, + pre_group_size: int, + allow_partial_group: bool = False, +): + group_ranks = [] + + if isinstance(global_ranks_or_sizes, list): + global_size = len(global_ranks_or_sizes) + global_ranks = global_ranks_or_sizes + else: + global_size = global_ranks_or_sizes + global_ranks = None + + real_global_size = global_size + + if allow_partial_group: + global_size = math.ceil(global_size / cur_group_size) * cur_group_size + + assert global_size % cur_group_size == 0, "err1" + + def _get_local_starts(): + for i in range(0, global_size, cur_group_size * pre_group_size): + for j in range(pre_group_size): + yield 0 + i + j + + for start in _get_local_starts(): + ranks = [ + start + i * pre_group_size for i in range(cur_group_size) if start + i * pre_group_size < real_global_size + ] + if global_ranks is not None: + ranks = [global_ranks[_idx] for _idx in ranks] + + group_ranks.append(ranks) + + assert len(group_ranks) == global_size // cur_group_size, f"{group_ranks}, {global_size}, {cur_group_size}" + + return group_ranks + + +def _create_parallel_process_groups( + global_ranks_or_sizes: int, + self_rank: int, + pre_group_size: int, + group_configs: List[GroupConfig], + with_cpu_group: bool = False, +): + group_results = [] + + for group in group_configs: + if group.anonymous is True: + pre_group_size = pre_group_size * group.size + continue + + group_ranks, accelerator_group = None, None + all_group_ranks = get_group_ranks(global_ranks_or_sizes, group.size, pre_group_size, group.allow_partial_group) + + for idx, ranks in enumerate(all_group_ranks): + _pg = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) + if self_rank in ranks: + group_ranks, accelerator_group = all_group_ranks[idx], _pg + else: + dist.destroy_process_group(_pg) + + if group_ranks is None: + pre_group_size = pre_group_size * group.size + continue + + cpu_group = init_cpu_group(accelerator_group, group_ranks, with_cpu_group) + + group_results.append( + (group_ranks.index(self_rank), len(group_ranks), accelerator_group, cpu_group, group_ranks, group.mode) + ) + + if len(group.subgroups) > 0: + subgroup_results = _create_parallel_process_groups( + global_ranks_or_sizes, self_rank, pre_group_size, group.subgroups, with_cpu_group + ) + group_results.extend(subgroup_results) + + pre_group_size = pre_group_size * group.size + + return group_results + + +def create_parallel_process_groups( + world_size: int, self_rank: int, group_configs: List[List[GroupConfig]], with_cpu_group: bool = False +): + group_results = [] + already_allocated_group = {} + + def _checker(order: str, result: Tuple[Any]) -> bool: + parallel_mode = result[-1] + + if parallel_mode not in already_allocated_group: + already_allocated_group[parallel_mode] = (order, result) + return True + else: + # check + ranks_in_group_idx = -2 + pre_order, pre_allocate_result = already_allocated_group[parallel_mode] + + error_msg = ( + f"The ranks allocated for {parallel_mode} are inconsistent in config {pre_order} and {order}: " + + f"{pre_allocate_result[ranks_in_group_idx]} != {result[ranks_in_group_idx]}" + ) + assert pre_allocate_result[ranks_in_group_idx] == result[ranks_in_group_idx], error_msg + + # release process group + dist.destroy_process_group(result[2]) # accelerator_group + if with_cpu_group: + dist.destroy_process_group(result[3]) # cpu_group + + return False + + for order, group_config in group_configs: + pre_group_size = 1 + + results = _create_parallel_process_groups( + world_size, + self_rank, + pre_group_size, + group_config, + with_cpu_group, + ) + + for result in results: + if _checker(order, result) is True: + group_results.append(result) + + return group_results + + +def create_single_process_group( + world_size: int, self_rank: int, config: GroupConfig, with_cpu_group: bool = False, pre_anonymous_size: int = 1 +): + pre_group_size = pre_anonymous_size + + return _create_parallel_process_groups( + world_size, + self_rank, + pre_group_size, + [config], + with_cpu_group, + )[0] + + +MTP_GROUP_ORDER = [ParallelMode.TENSOR, ParallelMode.DATA, ParallelMode.PIPELINE] +MTP_MOE_GROUP_ORDER = [ParallelMode.EXPERT_TENSOR, ParallelMode.EXPERT, ParallelMode.EXPERT_DATA, ParallelMode.PIPELINE] +ISP_SP_GROUP_ORDER = [ParallelMode.TENSOR, ParallelMode.DATA, ParallelMode.PIPELINE] +ISP_WP_GROUP_ORDER = [ParallelMode.WEIGHT, ParallelMode.WEIGHT_DATA, ParallelMode.PIPELINE] +ISP_MOE_GROUP_ORDER = [ParallelMode.EXPERT_WEIGHT, ParallelMode.EXPERT, ParallelMode.EXPERT_DATA, ParallelMode.PIPELINE] +FSDP_ORDER = [ParallelMode.DATA] # TODO: should we support moe for fsdp? + +SUBGROUP_SPEC = { + "mtp": { + ParallelMode.DATA: [ParallelMode.ZERO1], + }, + "isp": { + ParallelMode.WEIGHT_DATA: [ParallelMode.ZERO1], + }, # TODO: WEIGHT_ZERO1 + "fsdp": { + ParallelMode.DATA: [ParallelMode.ZERO3_DP, ParallelMode.ZERO1], + }, +} + + +def generate_parallel_group_configs( + parallel_strategy: str, parallel_sizes: Dict[ParallelMode, int], enable_moe: bool = False +) -> List[List[GroupConfig]]: + + group_configs = [] + subgroup_spec = SUBGROUP_SPEC.get(parallel_strategy, SUBGROUP_SPEC["mtp"]) + + def _recurse_generater(order: List[ParallelMode]): + config = [] + + for mode in order: + # disable pp process group for compatibility when pp size is 1. + anonymous = mode is ParallelMode.PIPELINE and parallel_sizes[mode] == 1 + + if mode not in subgroup_spec: + config.append(GroupConfig(mode, parallel_sizes[mode], anonymous)) + else: + config.append( + GroupConfig( + mode, parallel_sizes[mode], anonymous, subgroups=_recurse_generater(subgroup_spec[mode]) + ) + ) + + return config + + if parallel_strategy == "isp": + # sp configs + group_configs.append(("isp-sp", _recurse_generater(ISP_SP_GROUP_ORDER))) + # wp configs + group_configs.append(("isp-wp", _recurse_generater(ISP_WP_GROUP_ORDER))) + if enable_moe: + group_configs.append(("isp-moe", _recurse_generater(ISP_MOE_GROUP_ORDER))) + elif parallel_strategy == "fsdp": + group_configs.append(("fsdp", _recurse_generater(FSDP_ORDER))) + else: # 3d parallel: mtp, msp, fsp + group_configs.append(("3d", _recurse_generater(MTP_GROUP_ORDER))) + if enable_moe: + group_configs.append(("3d-moe", _recurse_generater(MTP_MOE_GROUP_ORDER))) + + return group_configs + + +def generate_2d_attn_process_group( + world_size: int, + self_rank: int, + config: Dict[str, Any], + parallel_sizes: Dict[ParallelMode, int], + with_cpu_group: bool = False, +): + + assert config.context_size * config.head_size == parallel_sizes[ParallelMode.SEQUENCE] + assert world_size % parallel_sizes[ParallelMode.SEQUENCE] == 0 + + if config.window_size >= 8 or config.window_size == config.context_size: + logger.warning("interleaved is forced False when window size > 8 or equals context size.") + config.interleaved = False + + if config.device_placement_strategy.head_first and config.head_size > 1: + logger.warning("interleaved is forced False when head_first is True and head size > 1.") + config.interleaved = False + + group_results = [] + sp_pre_group_size = 1 + for parallel_mode in ISP_SP_GROUP_ORDER: + if parallel_mode is ParallelMode.TENSOR: # assert sp is tp. + break + else: + sp_pre_group_size *= parallel_sizes[parallel_mode] + + # head and context process groups. + if config.device_placement_strategy.head_first: + group_configs = [ + GroupConfig(ParallelMode.HEAD, config.head_size), + GroupConfig(ParallelMode.CONTEXT, config.context_size), + ] + context_results_index = 1 + else: + group_configs = [ + GroupConfig(ParallelMode.CONTEXT, config.context_size), + GroupConfig(ParallelMode.HEAD, config.head_size), + ] + context_results_index = 0 + + group_results.extend( + _create_parallel_process_groups(world_size, self_rank, sp_pre_group_size, group_configs, with_cpu_group) + ) + + # window process groups. + window_num = config.context_size // config.window_size + cp_pre_group_size = 1 if context_results_index == 0 else config.head_size + every_context_ranks = get_group_ranks(world_size, config.context_size, cp_pre_group_size) + + def _gen_window_process_groups(context_ranks: List[int]): + if not config.device_placement_strategy.interleaved: + window_ranks = context_ranks + else: + _indexes = [ + j * 2 + i * config.window_size if i % 2 == 0 else j * 2 + 1 + (i - 1) * config.window_size + for i in range(window_num) + for j in range(config.window_size) + ] + window_ranks = [context_ranks[_i] for _i in _indexes] + + group_results.extend( + _create_parallel_process_groups( + window_ranks, + self_rank, + 1, + [ + GroupConfig(ParallelMode.INTRA_WINDOW, config.window_size), + GroupConfig(ParallelMode.INTER_WINDOW, window_num), + ], + with_cpu_group, + ) + ) + group_results.extend( + _create_parallel_process_groups( + window_ranks, + self_rank, + 1, + [ + GroupConfig(ParallelMode.DKV_INTRA_WINDOW, config.window_size), + GroupConfig(ParallelMode.DKV_INTER_WINDOW, window_num), + ], + with_cpu_group, + ) + ) + + for context_ranks in every_context_ranks: + _gen_window_process_groups(context_ranks) + + # print(get_group_ranks(window_ranks, config.window_size, 1)) + # print(get_group_ranks(window_ranks, window_num, config.window_size)) + + return group_results + + class ProcessGroupInitializer(ABC): """An object, knowing the parallelism configuration, that initializes parallel groups. @@ -1124,11 +1472,10 @@ class Initializer_GQA(ProcessGroupInitializer): """ def __init__(self, *args, **kwargs): + self.num_attention_heads = kwargs.pop("num_attention_heads") + self.num_kv_attention_heads = kwargs.pop("num_kv_attention_heads") super().__init__(*args, **kwargs) - # TODO: should adapt to general case - self.num_kv_attention_heads = 8 - self.NUM_ATTENTION_HEAD = 32 - self.kv_head_repeats_num = self.NUM_ATTENTION_HEAD // self.num_kv_attention_heads + self.kv_head_repeats_num = self.tensor_parallel_size // self.num_kv_attention_heads self.num_kv_group_per_tp = self.num_kv_attention_heads self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size @@ -1159,7 +1506,6 @@ def init_dist_group(self, use_cpu: bool = False): group_world_size = None mode = ParallelMode.GQA - # TODO: consider PP for i in range(self.data_parallel_size): for j in range(self.num_kv_group_per_tp): ranks = [ From ae2243c100dc253fb78430fbd851c1bc7a8ae4f4 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Tue, 10 Dec 2024 14:25:56 +0800 Subject: [PATCH 07/12] feat(dataloader): refine implementation of mocked and megatron dataloader (#344) Co-authored-by: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com> --- internlm/data/build_dataloader.py | 36 +++++++--- internlm/data/megatron/__init__.py | 2 - internlm/data/megatron/batch_sampler.py | 62 ----------------- internlm/data/megatron/collaters.py | 56 ++++++--------- internlm/data/megatron/dataset.py | 90 +++++-------------------- internlm/data/mocked/batch_sampler.py | 36 ++++++++-- internlm/data/mocked/dataset.py | 4 +- internlm/train/pipeline.py | 90 +++++++++++++++---------- 8 files changed, 150 insertions(+), 226 deletions(-) delete mode 100644 internlm/data/megatron/batch_sampler.py diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index 64da9539..e99bbfc7 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -2,12 +2,13 @@ import subprocess from functools import partial +import torch import torch.distributed as dist from torch.utils.data import ConcatDataset, DataLoader +from internlm.accelerator.abstract_accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.data.megatron.batch_sampler import MegatronBatchSampler from internlm.data.megatron.collaters import megatron_collate_fn from internlm.data.megatron.dataset import build_megatron_dataset from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler @@ -41,8 +42,8 @@ from internlm.utils.logger import get_logger from internlm.utils.utils import DataType -# global llm logger logger = get_logger(__file__) +internlm_accelerator = get_accelerator() def get_tokenized_train_loader_items(data_cfg): @@ -156,10 +157,14 @@ def get_streaming_train_loader_items(data_cfg): def get_megatron_train_loader_items(data_cfg): + assert data_cfg.get( + "pack_sample_into_one", False + ), "megatron dataloader curently only supports pack_sample_into_one=True" try: from internlm.data.megatron import helpers # noqa # pylint: disable=W0611 except ImportError: - if gpc.is_rank_for_log(): + # Compile dynamic library on-demand + if gpc.get_global_rank() % internlm_accelerator.device_count() == 0: subprocess.run( # noqa # pylint: disable=W1510 [ "g++", @@ -173,23 +178,28 @@ def get_megatron_train_loader_items(data_cfg): "internlm/data/megatron/helpers.cpp", "-o", "internlm/data/megatron/helpers.so", - ] + ], ) + torch.distributed.barrier() + + # NOTICE: Currently we only support single megatron dataset, a.k.a., single .bin and .idx + # Megatron dataset (.bin and.idx) should be generated by Megatron-LM tools/preprocess_data.py + # https://github.com/NVIDIA/Megatron-LM/blob/main/tools/preprocess_data.py train_ds = build_megatron_dataset( data_prefix=data_cfg.train_folder, - data_impl=data_cfg.get("data_impl", "infer"), - splits_string="1.0, 0.0, 0.0", - train_valid_test_num_samples=[9600000, 0, 0], seq_len=data_cfg.seq_len, seed=data_cfg.get("seed", 1024), - skip_warmup=True, ) - train_sampler = MegatronBatchSampler( - total_samples=len(train_ds), - consumed_samples=0, + train_sampler = StaticBatchSampler( + train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds], batch_size=data_cfg.micro_num * data_cfg.micro_bsz, + rampup_batch_size=data_cfg.rampup_batch_size, + micro_bsz=data_cfg.micro_bsz, + seed=data_cfg.get("seed", 1024), drop_last=True, + data_rank=gpc.get_local_rank(ParallelMode.DATA), + data_world_size=gpc.get_world_size(ParallelMode.DATA), ) train_collate_fn = partial( @@ -203,14 +213,18 @@ def get_mock_train_loader_items(data_cfg): assert data_cfg.get( "pack_sample_into_one", False ), "mocked dataloader curently only supports pack_sample_into_one=True" + train_ds = MockedDataset( train_folder=data_cfg.train_folder, micro_bsz=data_cfg.micro_bsz, micro_num=data_cfg.micro_num, seq_len=data_cfg.seq_len, ) + train_sampler = MockedSequentialBatchSampler(train_ds, data_cfg.micro_num) + train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz) + return train_ds, train_sampler, train_collate_fn diff --git a/internlm/data/megatron/__init__.py b/internlm/data/megatron/__init__.py index 5e447596..5405f6f8 100644 --- a/internlm/data/megatron/__init__.py +++ b/internlm/data/megatron/__init__.py @@ -1,9 +1,7 @@ -from .batch_sampler import MegatronBatchSampler from .collaters import megatron_collate_fn from .dataset import build_megatron_dataset __all__ = [ - "MegatronBatchSampler", "build_megatron_dataset", "megatron_collate_fn", ] diff --git a/internlm/data/megatron/batch_sampler.py b/internlm/data/megatron/batch_sampler.py deleted file mode 100644 index 049cfcf7..00000000 --- a/internlm/data/megatron/batch_sampler.py +++ /dev/null @@ -1,62 +0,0 @@ -import copy -import math - -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc - - -class MegatronBatchSampler: - """ - MegatronBatchSampler - """ - - def __init__(self, total_samples, consumed_samples, batch_size, drop_last=True): - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.batch_size = batch_size - self.drop_last = drop_last - - self.dp_rank = gpc.get_local_rank(ParallelMode.DATA) - self.dp_size = gpc.get_world_size(ParallelMode.DATA) - - # Sanity checks. - assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) - assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format( - self.consumed_samples, self.total_samples - ) - assert self.batch_size > 0 - assert self.dp_size > 0 - assert self.dp_rank < self.dp_size, "dp_rank should be smaller than dp_size: {}, " "{}".format( - self.dp_rank, self.dp_size - ) - - def __len__(self): - if self.drop_last and self.total_samples % self.dp_size != 0: - return math.ceil(self.total_samples - self.dp_size) / self.dp_size - else: - return math.ceil(self.total_samples / self.dp_size) - - def get_start_end_idx(self): - start_idx = self.dp_rank * self.batch_size - end_idx = start_idx + self.batch_size - return start_idx, end_idx - - def __iter__(self): - batch = [] - # Last batch will be dropped if drop_last is not set False - for idx in range(self.consumed_samples, self.total_samples): - batch.append(idx) - if len(batch) == self.batch_size * self.dp_size: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - batch = [] - - # Check the last partial batch and see drop_last is set - if len(batch) > 0 and not self.drop_last: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - - # TODO: implement copy method that compatible with InternEvo trainstate - def copy(self): - return copy.deepcopy(self) diff --git a/internlm/data/megatron/collaters.py b/internlm/data/megatron/collaters.py index 252bc289..c6ffc80e 100644 --- a/internlm/data/megatron/collaters.py +++ b/internlm/data/megatron/collaters.py @@ -2,48 +2,36 @@ def megatron_collate_fn(batch, micro_num, micro_bsz, seq_len): - - input_ids_result = [[] for _ in range(micro_num)] - labels_result = [[] for _ in range(micro_num)] - cu_seqlens = [] + input_ids_list = [[] for _ in range(micro_num)] + labels_list = [[] for _ in range(micro_num)] cu_seqlens_list = [] - indexes = [] indexes_list = [] - for i, item in enumerate(batch): - assert i < micro_num * micro_bsz - seq_len_list = item["text"] - assert len(seq_len_list) == seq_len + 1 - - micro_bsz_index = i % micro_bsz - micro_num_index = i // micro_bsz - - input_ids_result[micro_num_index].append(seq_len_list[:-1]) - labels_result[micro_num_index].append(seq_len_list[1:]) - - cu_seqlens.append(seq_len * micro_bsz_index) - indexes = indexes + list(range(seq_len)) + assert len(batch) == micro_bsz * micro_num + for idx, b in enumerate(batch): + tokens = b["text"] + # The length of megatron preprocessed data samples is (seq_len + 1) + # So we use the first seq_len tokens as input and the last seq_len tokens as shifted labels + assert len(tokens) == seq_len + 1 + micro_bsz_index = idx % micro_bsz + micro_num_index = idx // micro_bsz + input_ids_list[micro_num_index].append(tokens[:-1]) + labels_list[micro_num_index].append(tokens[1:]) if micro_bsz_index == micro_bsz - 1: - input_ids_result[micro_num_index] = torch.cat( - [torch.from_numpy(arr).long() for arr in input_ids_result[micro_num_index]], dim=0 + # Since megatron data sample is numpy format, we need to convert it to tensor and concate within micro batch + input_ids_list[micro_num_index] = torch.cat( + [torch.from_numpy(arr) for arr in input_ids_list[micro_num_index]], dim=0 ) - labels_result[micro_num_index] = torch.cat( - [torch.from_numpy(arr).long() for arr in labels_result[micro_num_index]], dim=0 + labels_list[micro_num_index] = torch.cat( + [torch.from_numpy(arr) for arr in labels_list[micro_num_index]], dim=0 ) - cu_seqlens.append(seq_len * micro_bsz) - cu_seqlens_list.append(torch.IntTensor(cu_seqlens)) - cu_seqlens = [] - indexes_list.append(torch.IntTensor(indexes)) - indexes = [] - - input_ids = torch.stack(input_ids_result) - labels = torch.stack(labels_result) - indexes = torch.stack(indexes_list) + cu_seqlens_list.append(torch.IntTensor([i * seq_len for i in range(micro_bsz + 1)])) + indexes_list.append(torch.IntTensor(list(range(seq_len)) * micro_bsz)) return { - "input_ids": input_ids, + "input_ids": torch.stack(input_ids_list), "cu_seqlens": cu_seqlens_list, - "indexes": indexes, + "indexes": torch.stack(indexes_list), "type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64), - }, labels + }, torch.stack(labels_list) diff --git a/internlm/data/megatron/dataset.py b/internlm/data/megatron/dataset.py index 7dba0294..88f4697b 100644 --- a/internlm/data/megatron/dataset.py +++ b/internlm/data/megatron/dataset.py @@ -1,5 +1,6 @@ # adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/gpt_dataset.py # adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/indexed_dataset.py + import hashlib import os import struct @@ -764,82 +765,25 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): return indexed_dataset -def get_train_valid_test_split_(splits_string, size): - """Get dataset splits from comma or '/' separated string list.""" - - splits = [] - if splits_string.find(",") != -1: - splits = [float(s) for s in splits_string.split(",")] - elif splits_string.find("/") != -1: - splits = [float(s) for s in splits_string.split("/")] - else: - splits = [float(splits_string)] - while len(splits) < 3: - splits.append(0.0) - splits = splits[:3] - splits_sum = sum(splits) - assert splits_sum > 0.0 - splits = [split / splits_sum for split in splits] - splits_index = [0] - for index, split in enumerate(splits): - splits_index.append(splits_index[index] + int(round(split * float(size)))) - diff = splits_index[-1] - size - for index in range(1, len(splits_index)): - splits_index[index] -= diff - assert len(splits_index) == 4 - assert splits_index[-1] == size - return splits_index - - def build_megatron_dataset( data_prefix, - data_impl, - splits_string, - train_valid_test_num_samples, seq_len, seed, - skip_warmup, - return_doc_ids=False, - *, - data_cache_path=None, ): - # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) - - total_num_of_documents = indexed_dataset.sizes.shape[0] - splits = get_train_valid_test_split_(splits_string, total_num_of_documents) - - # Print stats about the splits. - print_rank_0(" > dataset split:") - - def print_split_stats(index, name): - print_rank_0(" {}:".format(name)) - print_rank_0( - " document indices in [{}, {}) total of {} " - "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) - ) - - print_split_stats(0, "train") - - def build_dataset(index, name): - dataset = None - if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - dataset = GPTDataset( - name, - data_prefix, - documents, - indexed_dataset, - splits_string, - train_valid_test_num_samples[index], - seq_len, - seed, - return_doc_ids, - data_cache_path=data_cache_path, - ) - return dataset - - train_dataset = build_dataset(0, "train") - - return train_dataset + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl="infer", skip_warmup=True) + + # GPT dataset. + return GPTDataset( + name="train", + data_prefix=data_prefix, + documents=np.arange(start=0, stop=indexed_dataset.sizes.shape[0], step=1, dtype=np.int32), + indexed_dataset=indexed_dataset, + splits_string="1.0, 0.0, 0.0", # proportion of dataset for train/valid/test, we set 1.0 for train only + num_samples=gpc.config.data.micro_bsz + * gpc.config.data.micro_num + * gpc.get_world_size(ParallelMode.DATA) + * gpc.config.data.total_steps, # total number of train samples + seq_length=seq_len, + seed=seed, + ) diff --git a/internlm/data/mocked/batch_sampler.py b/internlm/data/mocked/batch_sampler.py index 737566fa..62f3dcea 100644 --- a/internlm/data/mocked/batch_sampler.py +++ b/internlm/data/mocked/batch_sampler.py @@ -1,24 +1,46 @@ -import copy - - class MockedSequentialBatchSampler: """ - MockedSequentialBatchSampler + A batch sampler that yields sequential batches of a specified size from a dataset. """ def __init__(self, train_ds, micro_num): + """ + Initialize the MockedSequentialBatchSampler. + + Args: + train_ds: The training dataset to sample from. + micro_num (int): The number of micro batches. + """ self.train_ds = train_ds self.micro_num = micro_num + self.batch_count = 0 + self.num_consumed_samples_in_epoch = 0 + def __iter__(self): num_samples = len(self.train_ds) - for start in range(0, num_samples, self.micro_num): + while self.num_consumed_samples_in_epoch < num_samples: + start = self.num_consumed_samples_in_epoch end = min(start + self.micro_num, num_samples) + self.batch_count += 1 + self.num_consumed_samples_in_epoch += end - start yield list(range(start, end)) def __len__(self): return (len(self.train_ds) + self.micro_num - 1) // self.micro_num - # TODO: implement copy method that compatible with InternEvo trainstate + def state_dict(self): + states = { + "batch_count": self.batch_count, + "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, + } + return states + + def load_state_dict(self, states): + self.batch_count = states["batch_count"] + self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"] + def copy(self): - return copy.deepcopy(self) + copy_sampler = MockedSequentialBatchSampler(self.train_ds, self.micro_num) + copy_sampler.load_state_dict(self.state_dict()) + return copy_sampler diff --git a/internlm/data/mocked/dataset.py b/internlm/data/mocked/dataset.py index 0d0e488e..88020a78 100644 --- a/internlm/data/mocked/dataset.py +++ b/internlm/data/mocked/dataset.py @@ -108,7 +108,7 @@ def __init__(self, train_folder: str, micro_bsz: int, micro_num: int, seq_len: i ] # simple sanity check: ensure loaded per-step data is equivalent to saved per-step data - self.sanity_check(tokens_list, labels_list) + self._sanity_check(tokens_list, labels_list) def __len__(self) -> int: return len(self.db_tokens) @@ -122,7 +122,7 @@ def __getitem__(self, idx: int) -> Dict[str, List[int]]: "type_ids": [0] * (self.micro_bsz * self.seq_len), } - def sanity_check(self, tokens_list: List[torch.Tensor], labels_list: List[torch.Tensor]): + def _sanity_check(self, tokens_list: List[torch.Tensor], labels_list: List[torch.Tensor]): tokens_list_tocheck = [] for i in range(len(self.db_tokens)): tokens_list_tocheck += self.db_tokens[i] diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 5907a4e3..79e9caf4 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -164,7 +164,7 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]): - def _check_module_pure_dp_wdp(name, module): # pylint: disable=W0613 + def _check_module_pure_dp(name, module): # pylint: disable=W0613 for param in module.parameters(): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) @@ -220,11 +220,13 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) for _chunk in unwrap_naive_amp(model): - # special case for pure dp or pure wdp mode - if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size( - ParallelMode.WEIGHT_DATA - ) == gpc.get_world_size(ParallelMode.GLOBAL): - _check_module_func = _check_module_pure_dp_wdp + # special case for pure dp mode + if ( + isinstance(gpc.config.parallel["tensor"], dict) + and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name + and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) + ): + _check_module_func = _check_module_pure_dp else: _check_module_func = _check_module # set param parallel attribute @@ -953,22 +955,34 @@ def traverse(module): def inject_config(model: nn.Module) -> None: + # Compatibility for Vision-Language Model if hasattr(model.config, "text_config"): - model_config = model.config.text_config + llm_cfg = model.config.text_config else: - model_config = model.config - gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size - gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = model_config.hidden_size - gpc.config.model.num_layers = gpc.config.NUM_LAYER = model_config.num_hidden_layers - gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = model_config.num_attention_heads - gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size + llm_cfg = model.config + gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = llm_cfg.vocab_size + gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = llm_cfg.hidden_size + gpc.config.model.num_layers = gpc.config.NUM_LAYER = llm_cfg.num_hidden_layers + # Compatibility for Mamba + if hasattr(llm_cfg, "num_attention_heads"): + gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = llm_cfg.num_attention_heads + gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = llm_cfg.intermediate_size / llm_cfg.hidden_size # For models that use GQA - if hasattr(model_config, "num_key_value_heads"): - gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = model_config.num_key_value_heads + if hasattr(llm_cfg, "num_key_value_heads"): + gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None: - # get inject_info + """ + Inject model helper functions. + + Args: + model (Union[nn.Module, nn.ModuleList]): + For built-in models, it is nn.Module for no pp and nn.ModuleList for pp. + For injected models, it is nn.Module. + inject_info (Optional[Dict]): configurations for injected_models. + """ + # parse inject_info if inject_info is not None: inject = inject_info.get("inject", False) interactive = inject_info.get("interactive", False) @@ -990,31 +1004,37 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt "norm": inject_norm, } + # inject config + if inject: + inject_config(model) + if not isinstance(model, nn.ModuleList): model = [model] - - # inject modules for _chunk in model: - if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size( - ParallelMode.WEIGHT_DATA - ) == gpc.get_world_size(ParallelMode.GLOBAL): + # Special case for pure dp mode: skip + if ( + isinstance(gpc.config.parallel["tensor"], dict) + and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name + and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) + ): continue + # In-place replacement or check for modules: "embed", "linear", "norm" + # (1) If inject=True, in-place replacement + # (2) If inject=False, check for mod in modules: inject_funcs[mod](_chunk, inject, interactive) - - # reset parameters and move model to device + # reset parameters if needed, model should have reset_parameters() method + if reset_params: + _chunk.reset_parameters() for _chunk in model: - if inject: - if reset_params: - _chunk.reset_parameters() + # If model is initialized on cpu, model should be moved to cuda device after injection + if not next(_chunk.parameters()).is_cuda: _chunk.to(get_current_device()) - # inject configs - if inject: - inject_config(model[0]) - if gpc.is_rank_for_log(): - logger.info( - f"inject is enabled, please check the model carefully, " - f"if there are any problems, please report issue to us. " - f"The injected model is \n {model}" - ) + # print injected model + if inject and gpc.is_rank_for_log(): + logger.info( + f"inject is enabled, please check the model carefully, " + f"if there are any problems, please report issue to us. " + f"The injected model is \n {model}" + ) From 5ad2eb02fb5be2196e505600fef459185070d1e3 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Tue, 10 Dec 2024 15:23:33 +0800 Subject: [PATCH 08/12] fix(pp): fix pp get tensor shape err and layernorm input dtype err (#378) --- internlm/core/scheduler/pipeline_scheduler_1f1b.py | 6 +++++- internlm/model/modeling_internlm.py | 4 ++-- internlm/model/modeling_internlm2.py | 2 +- internlm/model/modeling_llama.py | 2 +- internlm/model/modeling_mixtral.py | 4 ++-- internlm/model/modeling_moe.py | 4 ++-- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/internlm/core/scheduler/pipeline_scheduler_1f1b.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py index 4864c77f..289bc37d 100644 --- a/internlm/core/scheduler/pipeline_scheduler_1f1b.py +++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py @@ -35,7 +35,11 @@ def get_tensor_shape(): if not gpc.is_initialized(ParallelMode.PIPELINE): return None - if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"): + if ( + hasattr(gpc.config.data, "seq_len") + and hasattr(gpc.config.data, "micro_bsz") + and hasattr(gpc.config.model, "hidden_size") + ): if gpc.config.data.use_packed_dataset and gpc.is_evaluating is False: if gpc.config.parallel.sequence_parallel: sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index e2837724..ebf7d0b0 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -195,7 +195,7 @@ def _forward(self, hidden_states, *args, **kwargs): def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) _residual = _dropped - _hidden_states = self.norm1(_residual.float()) + _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: @@ -212,7 +212,7 @@ def _dropout_and_norm_attn(_hidden_states): def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.norm2(_residual.float()) + _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index fedd27c4..69da0837 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -257,7 +257,7 @@ def _dropout_and_norm_attn(_residual, _hidden_states): def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.ffn_norm(_residual.to(torch.float32)) + _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) return _residual, _hidden_states diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py index 46fc9c03..56b88e83 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/modeling_llama.py @@ -246,7 +246,7 @@ def _dropout_and_norm_attn(_residual, _hidden_states): def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.ffn_norm(_residual.to(torch.float32)) + _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) return _residual, _hidden_states diff --git a/internlm/model/modeling_mixtral.py b/internlm/model/modeling_mixtral.py index 844b5081..8e8767ce 100644 --- a/internlm/model/modeling_mixtral.py +++ b/internlm/model/modeling_mixtral.py @@ -214,7 +214,7 @@ def _forward(self, hidden_states, *args, **kwargs): def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) _residual = _dropped - _hidden_states = self.norm1(_residual.float()) + _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: @@ -231,7 +231,7 @@ def _dropout_and_norm_attn(_hidden_states): def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.norm2(_residual.float()) + _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index 964b268e..f40d35f3 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -205,7 +205,7 @@ def _forward(self, hidden_states, *args, **kwargs): def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) _residual = _dropped - _hidden_states = self.norm1(_residual.float()) + _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: @@ -222,7 +222,7 @@ def _dropout_and_norm_attn(_hidden_states): def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.norm2(_residual.float()) + _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype)) return _residual, _hidden_states if self.dropout_selective_checkpoint: From e60a50a7270d1bc43e3808e612f28db830f0121d Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Tue, 17 Dec 2024 14:59:52 +0800 Subject: [PATCH 09/12] feat(parallel_context.py): remove useless gqa process group (#390) --- internlm/core/context/parallel_context.py | 13 --- .../core/context/process_group_initializer.py | 81 ------------------- 2 files changed, 94 deletions(-) diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 8d74e3c6..f4751f59 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -662,19 +662,6 @@ def init_parallel_groups(self): group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe) group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False) - # process group for extra gqa tensor parallel. - if ( - "num_kv_attention_heads" in self.config.model - and self.config.model.num_kv_attention_heads < self.tensor_parallel_size - ): - group_results.append( - create_single_process_group( - world_size, - rank, - GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads), - ) - ) - # process group for network test. group_results.append( create_single_process_group( diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 5313ad92..1e805738 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -74,9 +74,6 @@ class ParallelMode(Enum): # real data parallel for isp ISP_DATA = "isp_data" - # grouped query attention - GQA = "gqa" - # sequence 2D parallel HEAD = "head" CONTEXT = "context" @@ -1454,84 +1451,6 @@ def init_dist_group(self, use_cpu: bool = False): return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode -class Initializer_GQA(ProcessGroupInitializer): - """A ProcessGroupInitializer for allreduce kv gradients with common attention head. - - Args: - rank (int): The rank of current process. - world_size (int): Size of whole communication world. - weight_parallel_size (int): Size of model weight parallel. - weight_data_parallel_size (int): Size of data parallel for common weight. - sequence_parallel_size (int): Size of data sequence parallel. - data_parallel_size (int): Size of data parallel. - pipeline_parallel_size (int): Size of pipeline parallel. - tensor_parallel_size (int): Size of tensor parallel. - zero1_parallel_size (int): Size of zero1 parallel. - nettest_parallel_size (int): Size of net testing parallel. - expert_parallel_size (int): Size of expert parallel. - """ - - def __init__(self, *args, **kwargs): - self.num_attention_heads = kwargs.pop("num_attention_heads") - self.num_kv_attention_heads = kwargs.pop("num_kv_attention_heads") - super().__init__(*args, **kwargs) - self.kv_head_repeats_num = self.tensor_parallel_size // self.num_kv_attention_heads - self.num_kv_group_per_tp = self.num_kv_attention_heads - self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size - - assert self.world_size % self.tensor_parallel_size == 0 - assert self.world_size % (self.pipeline_parallel_size * self.tensor_parallel_size) == 0 - assert self.pipeline_parallel_size == 1 - - def init_dist_group(self, use_cpu: bool = False): - """Initialize weight's data parallel groups, and assign local_ranks and groups to each gpu. - - Returns: - Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): - A WEIGHT_DATA parallelism's information tuple. - - n=128 sp=32 wp=64 zo1=1 with nopp - sp groups: [0-31] [32-63] [64-95] [96-127] - wp groups: [0-63] [64-127] - kv_head groups: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15] - [16,17,18,19] [20,21,22,23] [24,25,26,27] [28,29,30,31] - ... - ... - ... - """ - local_rank = None - ranks_in_group = None - process_group = None - cpu_group = None - group_world_size = None - mode = ParallelMode.GQA - - for i in range(self.data_parallel_size): - for j in range(self.num_kv_group_per_tp): - ranks = [ - i * self.tensor_parallel_size + j * self.kv_head_repeats_num + k - for k in range(self.kv_head_repeats_num) - ] - group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) - if use_cpu: - group_cpu = ( - dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) - if dist.get_backend() != "gloo" - else group - ) - else: - group_cpu = None - - if self.rank in ranks: - local_rank = ranks.index(self.rank) - group_world_size = len(ranks) - process_group = group - cpu_group = group_cpu - ranks_in_group = ranks - - return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode - - class Initializer_2D_SEQUENCE_PARALLEL(ProcessGroupInitializer): """ A ProcessGroupInitializer for 2D sequence parallel. From 0ec6cdc162b6632e67ac6821b10a8ec86c3197d7 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Tue, 17 Dec 2024 17:37:45 +0800 Subject: [PATCH 10/12] feat(isp): support switch for launch ag and forward overlap per module (#381) --- configs/7B_MoE4_sft.py | 12 +- configs/7B_isp_sft.py | 6 +- internlm/core/parallel/comm/isp.py | 160 +++++++++++------- internlm/initialize/launch.py | 8 +- internlm/model/ops/attention.py | 2 + ...zag_ring_flash_attn_with_sliding_window.py | 3 - tests/test_training/test_loss.py | 2 +- 7 files changed, 126 insertions(+), 67 deletions(-) diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index c558427c..4037c031 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -183,6 +183,10 @@ weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. launch_allgather_before: str, before which module to launch the all gather communication to + prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'. + Must be used with forward_overlap_per 'layer'. + 4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'. expert parallel (dict): 1. size: int * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size @@ -193,14 +197,18 @@ expert weight parallel (dict): 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. launch_allgather_before: str, before which module to launch the all gather communication to + prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'. + Must be used with forward_overlap_per 'layer'. + 4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'. """ parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True), + weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), expert=dict(size=-1, no_tp=False), - expert_weight=dict(size=1, overlap=True), + expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), ) cudnn_deterministic = False diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index de99f917..ad68082d 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -186,6 +186,10 @@ weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. launch_allgather_before: str, before which module to launch the all gather communication to + prefetch next layer's weight, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'. + Must be used with forward_overlap_per 'layer'. + 4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'. sequence_2D (dict): 1. enable: bool, whether enable the 2D sequence parallel or not. 2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses). @@ -205,7 +209,7 @@ zero1=dict(size=-1), tensor=dict(size=2, mode="isp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=4, overlap=True), + weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), sequence_2D=dict( enable=False, head_size=2, diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index d4950c75..7e722c2f 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -266,7 +266,6 @@ def __init__( dtype: torch.dtype = torch.half, device: torch.device = None, activation_checkpointing: float = 0.0, - module_shapes: Dict[str, torch.Size] = None, ) -> None: self.dtype = dtype if device is None: @@ -274,7 +273,6 @@ def __init__( else: self.device = device self.activation_checkpointing = activation_checkpointing - self.module_shapes = module_shapes class ISPOverlapState: @@ -285,7 +283,7 @@ class ISPOverlapState: def __init__(self) -> None: self.num_blocks: int = 0 self.ckpt_block_num: int = 0 - self.isp_outs: List[nn.Module] = [] + self.isp_prefetch_launch_module: List[nn.Module] = [] self.isp_modules: List[nn.Module] = [] self.index_to_isp_modules: Dict[int, nn.Module] = {} self.index_to_block: Dict[int, nn.Module] = {} @@ -315,8 +313,9 @@ def __init__( self.is_moe = is_moe self.is_forward = True self.reduce_scatter_handlers = {} - self._module_shapes = {} self._forward_prefetch_prerequisites = [] + self._forward_overlap_per = self._get_forward_overlap_granularity() + self._launch_before_module = self._get_launch_before_module() # real overlap state for each chunk. self._overlap_states: Dict[int, ISPOverlapState] = {} @@ -324,7 +323,7 @@ def __init__( # inner interface variables of overlap state. self._num_blocks = None self._ckpt_block_num = None - self._isp_outs = None + self._isp_prefetch_launch_module = None self._isp_modules = None # key: isp module; value: module global all-gather op handle self._weight_global_handle = None @@ -351,7 +350,32 @@ def __init__( self._register_sync_parameters_hook() # switch to chunk 0 at first. self.switch_current_model_chunk(0) - self.model_conf.module_shapes = self._module_shapes + + def _get_launch_before_module(self): + if self.is_moe is True: + _launch_before = gpc.config.parallel.expert_weight.get("launch_allgather_before", "wo") + else: + _launch_before = gpc.config.parallel.weight.get("launch_allgather_before", "wo") + + if _launch_before == "wqkv": + return ["wqkv", "Wqkv", "qkv", "q_a_proj", "q_proj"] + elif _launch_before == "attn": + return ["attn"] + elif _launch_before == "wo": + return ["out_proj", "wo"] + elif _launch_before == "w1": + return ["w1", "fused_w1_w3"] + else: + assert False, "launch module should be in ['wqkv', 'attn', 'wo', 'w1']" + + def _get_forward_overlap_granularity(self): + if self.is_moe is True: + _overlap_granularity = gpc.config.parallel.expert_weight.get("forward_overlap_per", "layer") + else: + _overlap_granularity = gpc.config.parallel.weight.get("forward_overlap_per", "layer") + + assert _overlap_granularity in ["module", "layer"] + return _overlap_granularity def _parse_model_structure(self, cid: int, model: nn.Module) -> None: self._overlap_states[cid] = ISPOverlapState() @@ -359,6 +383,13 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None: def get_model(obj: nn.Module) -> nn.Module: return get_model(obj.model) if hasattr(obj, "model") else obj + def is_allgather_launch_module(name, module): + return ( + hasattr(module, "is_attn_cls") + and getattr(module, "is_attn_cls") + and self._launch_before_module == ["attn"] + ) or (name.split(".")[-1] in self._launch_before_module) + # Important: only works for llama-class models children_name = get_model(model).named_children() for _, children in children_name: @@ -369,18 +400,12 @@ def get_model(obj: nn.Module) -> nn.Module: self._overlap_states[cid].index_to_isp_modules[idx] = [] self._overlap_states[cid].index_to_block[idx] = block for name, child in block.named_modules(): - if name.split(".")[-1] in ["out_proj", "wo"]: - self._overlap_states[cid].isp_outs.append(child) - self._overlap_states[cid].module_to_index[child] = idx + if is_allgather_launch_module(name, child): + self._overlap_states[cid].isp_prefetch_launch_module.append(child) if isinstance(child, (ParallelLinearWithCommExt)): if is_moe_param(child.weight) != self.is_moe: continue - if name not in self._module_shapes: - weight_parallel_size = dist.get_world_size(self.process_group) - origin_shape = tuple( - [child.weight.shape[0] * weight_parallel_size] + list(child.weight.shape[1:]) - ) - self._module_shapes[name] = torch.Size(origin_shape) + self._overlap_states[cid].module_to_index[child] = idx self._overlap_states[cid].isp_modules.append(child) self._overlap_states[cid].index_to_isp_modules[idx].append(child) @@ -403,25 +428,28 @@ def get_model(obj: nn.Module) -> nn.Module: self._overlap_states[cid].num_blocks = len(self._overlap_states[cid].index_to_isp_modules) def _all_gather_module_weight(self, module): + assert module not in self._bias_global_output and module not in self._weight_global_output with_bias = module.bias is not None # submit the all-gather communication for weight and bias. if with_bias: - bias_output, bias_handle = all_gather_raw( - module.bias, + if module not in self._bias_global_output: + bias_output, bias_handle = all_gather_raw( + module.bias, + self.process_group, + async_op=True, + ) + self._bias_global_handle[module] = bias_handle + self._bias_global_output[module] = bias_output + + if module not in self._weight_global_output: + weight_output, weight_handle = all_gather_raw( + module.weight, self.process_group, async_op=True, ) - self._bias_global_handle[module] = bias_handle - self._bias_global_output[module] = bias_output - - weight_output, weight_handle = all_gather_raw( - module.weight, - self.process_group, - async_op=True, - ) - self._weight_global_handle[module] = weight_handle - self._weight_global_output[module] = weight_output + self._weight_global_handle[module] = weight_handle + self._weight_global_output[module] = weight_output def _all_gather_block_weight(self, block_index: int): block = self._index_to_block[block_index] @@ -463,23 +491,20 @@ def _pre_forward_hook_for_first_block(self, *args): # pylint: disable=W0613 """ prefetch weight for block 0 before forward. """ - if self.is_forward is True: + if self._forward_overlap_per == "layer" and self.is_forward is True: self._all_gather_block_weight(0) - def _pre_forward_hook_for_last_ckpt_block(self, *args): # pylint: disable=W0613 - if self.is_forward is False: - self._all_gather_block_weight(self._ckpt_block_num - 1) - - def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args): # pylint: disable=W0613 + def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args): # pylint: disable=W0613 block_index = self._module_to_index[module] - if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False: - if block_index - 1 >= 0: - self._all_gather_block_weight(block_index - 1) - else: - # start the all-gather for next block - if block_index + 1 < self._num_blocks: - self._all_gather_block_weight(block_index + 1) + if self._forward_overlap_per == "layer": + if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False: + if block_index - 1 >= 0: + self._all_gather_block_weight(block_index - 1) + else: + # start the all-gather for next block + if block_index + 1 < self._num_blocks: + self._all_gather_block_weight(block_index + 1) def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 if module not in self._weight_global_handle: @@ -487,6 +512,32 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis self._wait_handle(module) + if self._forward_overlap_per == "module": + # start the all-gather for next module + # 1.forward prefetch for next module + module_index = self._isp_modules.index(module) + module_layer_id = self._module_to_index[module] + if module_index + 1 < len(self._isp_modules) and self.is_forward is True: + next_module = self._isp_modules[module_index + 1] + self._all_gather_module_weight(next_module) + + # 2.recompute forward prefetch for next module + if self.is_forward is False: + if module_index + 1 < len(self._isp_modules): + next_module = self._isp_modules[module_index + 1] + next_module_layer_id = self._module_to_index[next_module] + if module_layer_id == next_module_layer_id: + self._all_gather_module_weight(next_module) + # if current module is the last module in current layer, prefetch previous layer's first module + elif module_layer_id - 1 >= 0: + next_module = self._index_to_isp_modules[module_layer_id - 1][0] + self._all_gather_module_weight(next_module) + else: + # if current module is the last module, prefetch previous layer's first module + if module_layer_id - 1 >= 0: + next_module = self._index_to_isp_modules[module_layer_id - 1][0] + self._all_gather_module_weight(next_module) + def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False): self._clear_handle(module) @@ -515,29 +566,24 @@ def _register_sync_parameters_hook(self) -> None: register forward hooks and backward hooks for isp modules. """ # register forward hooks - # 1. register pre_forward_hook @block_0 to prefetch for block 0 - # 2. register pre_forward_hook @block_(ckpt_block_num-1) to prefetch for the last ckpt block - # 3. register pre_forward_hook @out_proj module to prefetch for next block, - # notice that next block's all_gather op should be after current block's all_to_all op - # 4. register pre_forward_hook @isp_module to wait handle for current module - # 5. register post_forward_hook @isp_module to release resource + # 1. register pre_forward_hook @block_0 to prefetch weight for block 0. + # 2. register pre_forward_hook @prefetch_launch_module to prefetch weight for next block, + # when forward overlap granularity is 'layer'. + # 3. register pre_forward_hook @isp_module to wait handle for current module, + # and prefetch weight for next module when forward overlap granularity is 'module'. + # 4. register post_forward_hook @isp_module to release memory resource. self._index_to_block[0].register_forward_pre_hook(self._pre_forward_hook_for_first_block) - if self._ckpt_block_num >= 1: - self._index_to_block[self._ckpt_block_num - 1].register_forward_pre_hook( - self._pre_forward_hook_for_last_ckpt_block - ) - - for out_proj in self._isp_outs: - out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj) + for module in self._isp_prefetch_launch_module: + module.register_forward_pre_hook(self._pre_forward_hook_for_prefetch_launch_module) for module in self._isp_modules: module.register_forward_pre_hook(self._pre_forward_hook_for_module) module.register_forward_hook(self._post_forward_hook_for_module) # register backward hooks - # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module - # 2. register post_backward_hook @isp_module to release resource + # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module. + # 2. register post_backward_hook @isp_module to release memory resource. if self._ckpt_block_num < self._num_blocks: for module in self._isp_modules: module.register_full_backward_pre_hook(self._pre_backward_hook_for_module) @@ -556,7 +602,7 @@ def communication_mode(self) -> str: return "wp" def switch_current_model_chunk(self, chunk_id: int) -> None: - self._isp_outs = self._overlap_states[chunk_id].isp_outs + self._isp_prefetch_launch_module = self._overlap_states[chunk_id].isp_prefetch_launch_module self._isp_modules = self._overlap_states[chunk_id].isp_modules self._weight_global_handle = self._overlap_states[chunk_id].weight_global_handle self._bias_global_handle = self._overlap_states[chunk_id].bias_global_handle @@ -872,9 +918,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten q, kv = _SeqAllToAll.apply(self.spg, [2, 3], [1, 1], q, kv) - torch.cuda.synchronize() context = self.local_attn(q, kv, *args, **kwargs) - torch.cuda.synchronize() context = _SeqAllToAll.apply(self.spg, 1, 2, context) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 1ac8ef31..35b3d646 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -94,13 +94,17 @@ def args_sanity_check(): gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name)) if "weight" not in gpc.config.parallel: - gpc.config.parallel._add_item("weight", dict(size=1, overlap=False)) + gpc.config.parallel._add_item( + "weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer") + ) if "expert" not in gpc.config.parallel: gpc.config.parallel._add_item("expert", dict(size=-1, no_tp=False)) if "expert_weight" not in gpc.config.parallel: - gpc.config.parallel._add_item("expert_weight", dict(size=1, overlap=False)) + gpc.config.parallel._add_item( + "expert_weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer") + ) if isinstance(gpc.config.parallel.pipeline, int): pp = gpc.config.parallel.pipeline diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index c2639569..604ea77a 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -886,6 +886,8 @@ class SelfAttention(nn.Module): attention_dropout (float): Dropout rate for attention scores. Defaults to 0.0. """ + is_attn_cls = True + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, layer_idx=0): super().__init__() self.causal = causal diff --git a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py b/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py index 6d531158..5c22fed3 100644 --- a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py +++ b/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py @@ -481,7 +481,6 @@ def forward( @staticmethod def backward(ctx, dout, *args): # pylint: disable=W0613 - torch.cuda.synchronize() q, k, v, out, softmax_lse = ctx.saved_tensors dq, dk, dv = zigzag_double_ring_flash_attn_backward( @@ -504,8 +503,6 @@ def backward(ctx, dout, *args): # pylint: disable=W0613 deterministic=ctx.deterministic, ) - torch.cuda.synchronize() - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index d1db7496..2fd8ad4c 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -109,7 +109,7 @@ def train( config.hybrid_zero_optimizer.overlap_sync_grad = False config.parallel.pipeline = dict(size=pp_size, mode=pp_mode) - config.parallel.weight = dict(size=wp_size, overlap=True) + config.parallel.weight = dict(size=wp_size, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer") if interleaved is True: config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True, mode=pp_mode) config.model.num_chunks = num_chunks From 141e9eb53df337477296f450d4f6928ff6420f23 Mon Sep 17 00:00:00 2001 From: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Date: Tue, 17 Dec 2024 20:28:46 +0800 Subject: [PATCH 11/12] feat(loss)/add different operator types for cross_entropy (#386) --- configs/7B_MoE4_sft.py | 14 + configs/7B_internlm2.py | 18 +- configs/7B_isp_sft.py | 16 + internlm/core/trainer_builder.py | 10 +- internlm/initialize/launch.py | 19 +- internlm/model/losses/__init__.py | 4 +- internlm/model/losses/ce_loss.py | 72 ++- internlm/model/metrics.py | 1 + internlm/model/ops/cross_entropy.py | 421 ++++-------------- .../model/ops/cross_entropy_ops/__init__.py | 11 + .../ops/cross_entropy_ops/apex_naive_loss.py | 77 ++++ .../ops/cross_entropy_ops/py_naive_loss.py | 83 ++++ .../py_vocab_parallel_loss.py | 160 +++++++ .../sequence_parallel_loss.py | 121 +++++ tests/test_infer/test_trainer_generate.py | 4 +- .../test_forward_output_no_fa.py | 4 +- tests/test_training/test_load_ckpt_loss.py | 4 +- tests/test_training/test_loss.py | 4 +- tests/test_training/test_no_fa_train_temp.py | 4 +- tests/test_training/test_norm_weight.py | 4 +- .../test_swap_nb_loss_and_gradnorm.py | 4 +- tests/test_training/train_CI.py | 4 +- 22 files changed, 682 insertions(+), 377 deletions(-) create mode 100644 internlm/model/ops/cross_entropy_ops/__init__.py create mode 100644 internlm/model/ops/cross_entropy_ops/apex_naive_loss.py create mode 100644 internlm/model/ops/cross_entropy_ops/py_naive_loss.py create mode 100644 internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py create mode 100644 internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 4037c031..8d8acc40 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -103,6 +103,20 @@ clip_grad_norm=1.0, ) + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. + loss = dict( label_smoothing=0, moe_loss_coeff=0.1, diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 97758bba..51741703 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -98,9 +98,21 @@ clip_grad_norm=1.0, ) -loss = dict( - label_smoothing=0, -) + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. +loss = dict(label_smoothing=0, op_type="py_vocab_parallel") adam = dict( lr=1e-4, diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index ad68082d..39c78660 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -108,8 +108,24 @@ clip_grad_norm=1.0, ) + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. + loss = dict( label_smoothing=0, + op_type="flash_vocab_parallel", ) adam = dict( diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index d0ef284d..2b82bc1f 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -16,7 +16,7 @@ from internlm.data.train_state import get_train_state from internlm.eval.evaluation import evaluate_on_val_dls from internlm.initialize.initialize_trainer import initialize_trainer -from internlm.model.losses.ce_loss import FlashGPTLMLoss +from internlm.model.losses.ce_loss import InternLoss from internlm.model.metrics import AccPerplex from internlm.monitor.monitor import send_alert_message from internlm.train.pipeline import ( @@ -172,9 +172,11 @@ def _read_config(self, config_path: str) -> list: with open(config_path, "r") as f: return f.readlines() - def _initialize_criterion(self) -> FlashGPTLMLoss: - return FlashGPTLMLoss( - parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing + def _initialize_criterion(self) -> InternLoss: + return InternLoss( + parallel_output=gpc.config.model.parallel_output, + label_smoothing=gpc.config.loss.label_smoothing, + op_type=gpc.config.loss.op_type, ) def _initialize_checkpoint_manager( diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 35b3d646..c8b16516 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -351,17 +351,6 @@ def args_sanity_check(): if "use_flash_attn" not in gpc.config.model: gpc.config.model._add_item("use_flash_attn", True) - old_parallel_output = gpc.config.model.get("parallel_output", None) - # Try to change user setting - if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU: - gpc.config.model.update({"parallel_output": False}) - if old_parallel_output is True and gpc.is_rank_for_log(): - logger.warning( - "'parallel_output' is converted from 'True' to 'False'." - "Because 'parallel_output' only support by FlashCrossEntropyLoss." - "Please make sure you are using flash attention in cuda device." - ) - if "MoE" in gpc.config.get("model_type", ModelType.INTERNLM.name): if "num_experts" not in model: model._add_item("num_experts", 1) @@ -449,6 +438,9 @@ def args_sanity_check(): ]: gpc.config.parallel.sequence_parallel = True + if gpc.config.model.get("parallel_output", False) is False: + logger.warning("When enable sequence parallel, it recommend to enable parallel_output") + # set default value for weight parallel if gpc.config.parallel["weight"].get("overlap", None) is None: gpc.config.parallel["weight"]["overlap"] = False @@ -583,6 +575,11 @@ def args_sanity_check(): gpc.config.data.use_packed_dataset is False ), "only unpacked data is supported when using 2D sequence parallel." + # loss operator type + loss_cfg = gpc.config.loss + if loss_cfg.get("op_type", None) is None: + loss_cfg._add_item("op_type", "py_vocab_parallel") + def launch( config: Union[str, Path, Config, Dict], diff --git a/internlm/model/losses/__init__.py b/internlm/model/losses/__init__.py index 58287815..5d6c8db3 100644 --- a/internlm/model/losses/__init__.py +++ b/internlm/model/losses/__init__.py @@ -1,5 +1,5 @@ -from .ce_loss import FlashGPTLMLoss +from .ce_loss import InternLoss __all__ = [ - "FlashGPTLMLoss", + "InternLoss", ] diff --git a/internlm/model/losses/ce_loss.py b/internlm/model/losses/ce_loss.py index 69e09d2f..5b2a380e 100644 --- a/internlm/model/losses/ce_loss.py +++ b/internlm/model/losses/ce_loss.py @@ -1,36 +1,61 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - +import torch from torch import nn -from internlm.core.context import global_context as gpc +from internlm.accelerator import get_accelerator from internlm.model.ops.cross_entropy import new_cross_entropy -from internlm.utils.logger import get_logger -logger = get_logger(__file__) +internlm_accelerator = get_accelerator() -class FlashGPTLMLoss(nn.Module): - """ - Loss function for flash GPT Language Model. +class InternLoss(nn.Module): + """We use a base class to wrap different CrossEntropy implementations + and unify input and output parameters. + + This class is designed not to rely on gpc, making it easy to transplant. + + Different variants of CrossEntropy, with supporting parallel computation and inplace operations. + + If parallel_output is False, the output will gather head's output, only 'FlashCrossEntropyLoss' and + 'CrossEntropyApexVocabParallel' support it. """ - def __init__(self, parallel_output=True, label_smoothing=0): + def __init__( + self, + parallel_output=False, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, + inplace_backward=True, + op_type="py_vocab_parallel", + ) -> None: super().__init__() if label_smoothing is not None: if label_smoothing != 0: - if gpc.is_rank_for_log(): - print(f"use label_smoothing: {label_smoothing}") + print(f"use label_smoothing: {label_smoothing}", flush=True) else: label_smoothing = 0 self.label_smoothing = label_smoothing + + self.reduction = reduction + self.ignore_index = ignore_index + self.op_type = op_type + + assert self.reduction in [ + "mean", + "none", + ], f"Only support reduction is mean/none, but the passed in reduction is {self.reduction}" + + # In order to facilitate the calculation of loss for different datasets, we set reduction as 'none', + # and do loss reduction ourselves. self.loss_fn = new_cross_entropy( - reduction="mean", - label_smoothing=self.label_smoothing, + op_type=op_type, + ignore_index=ignore_index, + label_smoothing=label_smoothing, parallel_output=parallel_output, - inplace_backward=True, + inplace_backward=inplace_backward, + reduction="none", ) def forward(self, *args): @@ -44,9 +69,18 @@ def forward(self, *args): raise RuntimeError(f"The number of criterion inputs are:{len(args)}") shift_logits = logits.contiguous().view(-1, logits.size(-1)) shift_labels = labels.contiguous().view(-1) - loss = self.loss_fn( - shift_logits, shift_labels - ) # There is no need to consider the ignore_index problem here, because the loss calculation will be - # calculated through the calculation range, and -100 must be outside this range, so there is no problem + + with torch.autocast(device_type=internlm_accelerator.get_backend_name()): + loss_list = self.loss_fn( + shift_logits, shift_labels + ) # There is no need to consider the ignore_index problem here, because the loss calculation will be + # # calculated through the calculation range, and -100 must be outside this range, so there is no problem + + cond = shift_labels != self.ignore_index + if self.reduction == "mean": + # This loss is only for one dp rank. + loss = loss_list.sum() / (cond).sum() + else: + loss = loss_list return loss diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index af52858f..a7f6c966 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -305,6 +305,7 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: reduction="none", parallel_output=gpc.config.model.parallel_output, inplace_backward=True, + op_type=gpc.config.loss.op_type, ) self.scatter_sum = scatter_sum_impl diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py index 82a2da70..99bf1e04 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/ops/cross_entropy.py @@ -6,354 +6,131 @@ This file implements support for the cross entropy operators. """ +from enum import Enum + import torch -import torch.distributed as dist from torch import nn from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.ops.cross_entropy_ops import ( + CrossEntropyApexVocabParallel, + CrossEntropyLossApex, + CrossEntropyPython, +) from internlm.utils.logger import get_logger logger = get_logger(__file__) internlm_accelerator = get_accelerator() -# Adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/ \ -# sequence_parallel/cross_entropy.py -class _VocabSequenceParallelCrossEntropy(torch.autograd.Function): - """ - Cross Entropy module for isp. - """ - - @staticmethod - def forward(ctx, vocab_seq_parallel_logits, target, reduction, label_smoothing=0.0): # pylint: disable=W0613 - sp_size = gpc.get_world_size(ParallelMode.TENSOR) - - # reshape - # vocab_seq_parallel_logits: [B * (S/P), V] -> [B, S/P, V] - # target: [B * S/P] -> [B, S/P] - bsz = gpc.config.data.micro_bsz if gpc.config.data.use_packed_dataset is False else 1 - vocab_seq_parallel_logits = vocab_seq_parallel_logits.view(bsz, -1, gpc.config.model.vocab_size) - target = target.view(bsz, -1) - - # transpose - # vocab_seq_parallel_logits: [B, S/P, V] -> [S/P, B, V] - # target: [B, S/P] -> [S/P, B] - # return: [S, B] - vocab_seq_parallel_logits = vocab_seq_parallel_logits.transpose(0, 1).contiguous() - target = target.transpose(0, 1).contiguous() - - ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_size - batch_size = vocab_seq_parallel_logits.size(1) - - # Need softmax for backward - softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1) - ctx.vocab_size = vocab_seq_parallel_logits.size(2) - loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction="none") - - loss_all = torch.empty( - ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device - ) +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, group=gpc.get_group(ParallelMode.DATA)) + averaged_losses = averaged_losses / gpc.get_world_size(ParallelMode.DATA) - torch.distributed.all_gather_into_tensor(loss_all, loss, group=gpc.get_group(ParallelMode.TENSOR)) + return averaged_losses - # [s b] => [b, s] - loss_all = loss_all.transpose(0, 1).contiguous() - ctx.save_for_backward(softmax, target) +class CrossEntropyOpType(Enum): + torch_naive = 1 # CrossEntropy from torch + flash_vocab_parallel = 2 # VocabParallel CorssEntropy from flash_attn + apex_naive = 3 # CrossEntropy from apex + py_vocab_parallel = 4 # self-implemented VocabParallel CrossEntropy + py_naive = 5 # self-implemented CrossEntropy + # sequence_parallel = 6 # self-implemented SequenceParallel CrossEntropy - return loss_all - @staticmethod - def backward(ctx, grad_output): - softmax, target = ctx.saved_tensors +cross_entropy_op_name_map = { + "torch_naive": CrossEntropyOpType.torch_naive, + "flash_vocab_parallel": CrossEntropyOpType.flash_vocab_parallel, + "apex_naive": CrossEntropyOpType.apex_naive, + "py_vocab_parallel": CrossEntropyOpType.py_vocab_parallel, + "py_naive": CrossEntropyOpType.py_naive, + # "sequence_parallel": CrossEntropyOpType.sequence_parallel, +} - # transpose - grad_output = grad_output.transpose(0, 1).contiguous() - step_seqlen = ctx.seqlen // gpc.get_world_size(ParallelMode.TENSOR) - sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1), :] +# TODO: ops是否需要实现更加统一的形式 +def new_cross_entropy( + op_type: str = "py_vocab_parallel", + ignore_index: int = -100, + label_smoothing: float = 0, + parallel_output: bool = False, + inplace_backward: bool = True, + reduction: str = "none", +): + try: + op_type = cross_entropy_op_name_map[op_type] + except KeyError: + raise KeyError(f"op_type only support: {cross_entropy_op_name_map.keys()}") - grad_input = softmax - grad_2d = grad_input.view(-1, ctx.vocab_size) - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU: + assert op_type in [ + CrossEntropyOpType.torch_naive, + CrossEntropyOpType.py_vocab_parallel, + ], "no-GPU env only support 'torch_naive' or 'py_vocab_parallel loss function" - grad_2d[arange_1d, target.view(-1)] -= 1 - grad_input.mul_(grad_output_part.unsqueeze(dim=-1)) + if op_type == CrossEntropyOpType.torch_naive: - # transpose - grad_input = grad_input.transpose(0, 1).contiguous() - # reshape - grad_input = grad_input.view(-1, gpc.config.model.vocab_size) + assert parallel_output is False, ( + "'torch_naive' (nn.CrossEntropyLoss) don't support parallel_output, " + "try use 'flash_vocab_parallel' or 'py_vocab_parallel'" + ) - return grad_input, None, None + return nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing, ignore_index=ignore_index) + elif op_type == CrossEntropyOpType.flash_vocab_parallel: -def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): - return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) + assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None." + try: + from flash_attn.losses.cross_entropy import ( + CrossEntropyLoss as FlashCrossEntropyLoss, + ) -def average_losses_across_data_parallel_group(losses): - """Reduce a tensor of losses across all GPUs.""" - averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) - torch.distributed.all_reduce(averaged_losses, group=gpc.get_group(ParallelMode.DATA)) - averaged_losses = averaged_losses / gpc.get_world_size(ParallelMode.DATA) + flash_cross_entropy_impl = True + except (ModuleNotFoundError, ImportError): + flash_cross_entropy_impl = False - return averaged_losses + assert ( + gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl + ), "Only flash cross entropy support parallel_output" + assert ( + internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + ), "flash cross entropy only support gpu backend" -class VocabSequenceParallelCrossEntropyLoss(nn.Module): - """ - Cross Entropy module for isp. - """ - - def __init__( - self, - ignore_index: int = -100, - reduction: str = "mean", - label_smoothing: float = 0, - process_group=None, - ): - super().__init__() - if reduction not in ["mean", "none"]: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.process_group = process_group - - def loss_mean_func(self, output_tensor): - losses = output_tensor.float() - loss = torch.sum(losses.view(-1)) / losses.numel() - - # TODO: allreduce loss in dp group - - return loss - - def forward(self, _input, target): - assert _input.is_cuda and target.is_cuda - - _loss_list = vocab_sequence_parallel_cross_entropy(_input, target, self.label_smoothing) - - if self.reduction == "mean": - loss = self.loss_mean_func(_loss_list) - return loss - - return _loss_list.view(-1) - - -class _VocabParallelCrossEntropy(torch.autograd.Function): - """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py - Supports vocab parallel loss calculation, but does not support inplace backward. - NOTE: This class is different from the original Apex implementation. Apex will calculate the loss of - ignore_index and flashCrossEntropy will set it to 0. InterEvo adapts the second approach. - """ - - @staticmethod - @internlm_accelerator.amp.custom_fwd - def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, process_group=None): - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - if process_group is not None and dist.get_world_size(process_group) > 1: - torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) - # Subtract the maximum value. - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - # get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] - if process_group is not None and dist.get_world_size(process_group) > 1: - rank = dist.get_rank(process_group) - # world_size = dist.get_world_size(process_group) - part_len = vocab_parallel_logits.shape[-1] - vocab_start_index, vocab_end_index = part_len * rank, part_len * (rank + 1) - else: - vocab_start_index, vocab_end_index = 0, vocab_parallel_logits.shape[-1] - - # vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - ignore_mask = target == -100 - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - - # All reduce is needed to get the chunks from other GPUs. - if process_group is not None and dist.get_world_size(process_group) > 1: - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - - if process_group is not None and dist.get_world_size(process_group) > 1: - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Normalize and optionally smooth logits - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - - # Loss = log(sum(exp(logits))) - predicted-logit. - sum_exp_logits = torch.log(sum_exp_logits) - loss = sum_exp_logits - predicted_logits - loss[ignore_mask] = 0.0 - - vocab_size = exp_logits.size(-1) - if label_smoothing > 0: - r""" - We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. - = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) - = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i - = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K - From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py - """ - assert 1.0 > label_smoothing > 0.0 - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - - # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. - log_probs = torch.log(exp_logits) - mean_log_probs = log_probs.mean(dim=-1) - loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs - - ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size - # Store softmax, target-mask and masked-target for backward pass. - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, ignore_mask) - - return loss - - @staticmethod - @internlm_accelerator.amp.custom_bwd - def backward(ctx, grad_output): - - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d, ignore_mask = ctx.saved_tensors - label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size - - # All the inputs have softmax as thier gradient. - grad_input = softmax # s_{k} - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - - softmax_update = 1.0 - target_mask.view(-1).float() - - if label_smoothing > 0: - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update - average_grad = 1 / vocab_size - grad_2d[arange_1d, :] -= smoothing * average_grad - else: - grad_2d[arange_1d, masked_target_1d] -= softmax_update - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - grad_input[ignore_mask] = 0.0 # set ignore token loss as 0. - - return grad_input, None, None, None - - -class CrossEntropyApexVocabParallel(nn.Module): - """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py - Supports vocab parallel loss calculation, but does not support inplace backward. - """ - - def __init__( - self, ignore_index=-100, reduction="mean", label_smoothing=0.0, process_group=None, inplace_backward=False - ): - super().__init__() - if reduction not in ["mean", "none"]: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - assert inplace_backward is False, "does not support inplace backward" - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.process_group = process_group - - def forward(self, vocab_parallel_logits, target): - # assert vocab_parallel_logits.is_cuda and vocab_parallel_logits.is_cuda - - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, self.label_smoothing, self.process_group) - if self.reduction == "mean": - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss - - -def flash_loss( - ignore_index=-100, - reduction="mean", - label_smoothing=0.0, - process_group=None, - inplace_backward=False, # pylint:disable=W0613 -): - try: - from flash_attn.losses.cross_entropy import ( - CrossEntropyLoss as FlashCrossEntropyLoss, + logger.warning( + "You are using flash_attn cross_entropy operators, \ + which may result loss divergency in long sequence." ) - flash_cross_entropy_impl = True - except (ModuleNotFoundError, ImportError): - flash_cross_entropy_impl = False - - assert ( - gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl - ), "Only flash cross entropy support parallel_output" - - assert ( - internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU - ), "flash cross entropy only support gpu backend" + return FlashCrossEntropyLoss( + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + process_group=gpc.get_group(ParallelMode.TENSOR), + inplace_backward=inplace_backward, + ) - return FlashCrossEntropyLoss( - ignore_index=ignore_index, - reduction=reduction, - label_smoothing=label_smoothing, - process_group=process_group, - ) + elif op_type == CrossEntropyOpType.apex_naive: + assert parallel_output is False, ( + "'apex_naive' (nn.CrossEntropyLoss) can'ts support parallel_output," + "try use 'flash_vocab_parallel' or 'py_vocab_parallel'" + ) + return CrossEntropyLossApex( + ignore_index=ignore_index, + reduction=reduction, + inplace_backward=inplace_backward, + label_smoothing=label_smoothing, + ) -# TODO: ops是否需要实现更加统一的形式 -def new_cross_entropy( - ignore_index: int = -100, - reduction: str = "mean", - label_smoothing: float = 0, - parallel_output: bool = False, - **kwargs, -): - # if is_using_isp() and parallel_output: - # if gpc.is_rank_for_log(): - # logger.warning("Use VocabSequenceParallelCrossEntropyLoss.") - # return VocabSequenceParallelCrossEntropyLoss( - # ignore_index=ignore_index, - # reduction=reduction, - # label_smoothing=label_smoothing, - # process_group=gpc.get_group(ParallelMode.TENSOR), - # ) - - if parallel_output: - # return flash_loss( - # ignore_index=ignore_index, - # reduction=reduction, - # label_smoothing=label_smoothing, - # process_group=gpc.get_group(ParallelMode.TENSOR), - # ) + elif op_type == CrossEntropyOpType.py_vocab_parallel: + assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None." return CrossEntropyApexVocabParallel( ignore_index=ignore_index, @@ -361,13 +138,13 @@ def new_cross_entropy( label_smoothing=label_smoothing, process_group=gpc.get_group(ParallelMode.TENSOR), ) - else: - if gpc.is_rank_for_log(): - logger.warning( - "Use nn.CrossEntropyLoss rather than flashattn CrossEntropyLoss." - "parallel_output must be set false. Please note this!" - ) - kwargs.pop("inplace_backward", None) - return nn.CrossEntropyLoss( - ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing, **kwargs + + elif op_type == CrossEntropyOpType.py_naive: + assert parallel_output is False, ( + "'py_naive' (nn.CrossEntropyLoss) don't support parallel_output," + "try use 'flash_vocab_parallel' or 'py_vocab_parallel'" ) + return CrossEntropyPython(ignore_index=ignore_index, reduction=reduction) + + else: + raise RuntimeError(f"unkown loss function type: {op_type}") diff --git a/internlm/model/ops/cross_entropy_ops/__init__.py b/internlm/model/ops/cross_entropy_ops/__init__.py new file mode 100644 index 00000000..1f4b6630 --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/__init__.py @@ -0,0 +1,11 @@ +from .apex_naive_loss import CrossEntropyLossApex +from .py_naive_loss import CrossEntropyPython +from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel +from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss + +__all__ = [ + "CrossEntropyLossApex", + "CrossEntropyPython", + "CrossEntropyApexVocabParallel", + "VocabSequenceParallelCrossEntropyLoss", +] diff --git a/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py b/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py new file mode 100644 index 00000000..139f20a2 --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py @@ -0,0 +1,77 @@ +import torch +from torch import nn + +from internlm.accelerator import get_accelerator + +try: + import xentropy_cuda_lib +except (ImportError, ModuleNotFoundError): + has_xentropy_cuda_lib = False +else: + has_xentropy_cuda_lib = True + + +internlm_accelerator = get_accelerator() + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + """ + Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py + Inplace backward is supported, but loss calculation of vocab parallel is not supported. + NOTE: it should be noted that when the pack_length exceeds 40K, the loss will not decrease. + """ + + @staticmethod + @internlm_accelerator.amp.custom_fwd + def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False): + losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels == padding_idx, 0) + ctx.save_for_backward(logits, max_log_sum_exp, labels) + ctx.smoothing = smoothing + ctx.padding_idx = padding_idx + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + @internlm_accelerator.amp.custom_bwd + def backward(ctx, grad_loss): + logits, max_log_sum_exp, labels = ctx.saved_tensors + if not grad_loss.is_contiguous(): + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels == ctx.padding_idx, 0) + grad_logits = xentropy_cuda_lib.backward( + grad_loss, logits, max_log_sum_exp, labels, ctx.smoothing, ctx.inplace_backward + ) + return grad_logits, None, None, None, None + + +class CrossEntropyLossApex(nn.Module): + """ + Inplace backward is supported, but loss calculation of vocab parallel is not supported. + NOTE: it should be noted that when the pack_length exceeds 40K, the loss will not decrease. + """ + + def __init__(self, ignore_index=-100, reduction="mean", label_smoothing=0.0, inplace_backward=False): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + + assert ( + has_xentropy_cuda_lib is True + ), "The 'xentropy_cuda_lib' package which CrossEntropyLossApex needed was not found in your environment!" + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, logits, target): + # assert logits.is_cuda and target.is_cuda + + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + logits, target, self.label_smoothing, self.ignore_index, self.inplace_backward + ) + if self.reduction == "mean": + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/internlm/model/ops/cross_entropy_ops/py_naive_loss.py b/internlm/model/ops/cross_entropy_ops/py_naive_loss.py new file mode 100644 index 00000000..f391933f --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/py_naive_loss.py @@ -0,0 +1,83 @@ +import torch +from torch import nn + +from internlm.accelerator import get_accelerator + +internlm_accelerator = get_accelerator() + + +class CrossEntropyWriteInPython(torch.autograd.Function): + """baseline for unit test.""" + + @staticmethod + @internlm_accelerator.amp.custom_fwd + def forward(ctx, logits, target, ignore_idx): + # (1) cal mask + ignore_mask = target == ignore_idx + target[ignore_mask] = 0 + + # (2) safe softmax for logist + logits_max = torch.max(logits, dim=-1)[0] + logits = logits - logits_max.unsqueeze(dim=-1) + + # (3) cal predicted_logits + vocab_size = logits.shape[-1] + logits_2d = logits.view(-1, vocab_size) + target = target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits = logits_2d[arange_1d, target].clone().contiguous().view_as(target) + + # (4) softmax + exp_logits = logits + torch.exp(logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + + # (5) Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + # (6) cal log + sum_exp_logits = torch.log(sum_exp_logits) + + # (7) cal loss + loss = sum_exp_logits - predicted_logits + + # (8) apply ignore_mask + loss[ignore_mask] = 0.0 + ctx.save_for_backward(exp_logits, target, ignore_mask) + return loss + + @staticmethod + @internlm_accelerator.amp.custom_bwd + def backward(ctx, grad_output): + # The deriving of cross entropy ref: + # https://shivammehta25.github.io/posts/deriving-categorical-cross-entropy-and-softmax/ + softmax, target, ignore_mask = ctx.saved_tensors + + # Add the gradient from matching classes(which is indicate by target). + grad_input = softmax + grad_2d = grad_input.view(-1, softmax.shape[-1]) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, target] -= 1.0 + + grad_input.mul_(grad_output.unsqueeze(dim=-1)) # elementwise multiplication + grad_input[ignore_mask] = 0.0 # set ignore token loss as 0. + + return grad_input, None, None, None + + +class CrossEntropyPython(nn.Module): + """ + Baseline for unit test. Please do not use this class directly. + """ + + def __init__(self, ignore_index=-100, reduction="mean"): + super().__init__() + self.ignore_index = ignore_index + self.reduction = reduction + + def forward(self, logits, target): + loss = CrossEntropyWriteInPython.apply(logits, target, self.ignore_index) + if self.reduction == "mean": + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py b/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py new file mode 100644 index 00000000..6f5457c8 --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py @@ -0,0 +1,160 @@ +import torch +import torch.distributed as dist +from torch import nn + +from internlm.accelerator import get_accelerator + +internlm_accelerator = get_accelerator() + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py + Supports vocab parallel loss calculation, but does not support inplace backward. + NOTE: This class is different from the original Apex implementation. Apex will calculate the loss of + ignore_index and flashCrossEntropy will set it to 0. InterEvo adapts the second approach. + """ + + @staticmethod + @internlm_accelerator.amp.custom_fwd + def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, process_group=None): + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + if process_group is not None and dist.get_world_size(process_group) > 1: + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) + # Subtract the maximum value. + vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) + + # Get the partition's vocab indecies + # get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + if process_group is not None and dist.get_world_size(process_group) > 1: + rank = dist.get_rank(process_group) + # world_size = dist.get_world_size(process_group) + part_len = vocab_parallel_logits.shape[-1] + vocab_start_index, vocab_end_index = part_len * rank, part_len * (rank + 1) + else: + vocab_start_index, vocab_end_index = 0, vocab_parallel_logits.shape[-1] + + # vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + ignore_mask = target == -100 + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + # All reduce is needed to get the chunks from other GPUs. + if process_group is not None and dist.get_world_size(process_group) > 1: + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + + if process_group is not None and dist.get_world_size(process_group) > 1: + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + # Loss = log(sum(exp(logits))) - predicted-logit. + sum_exp_logits = torch.log(sum_exp_logits) + loss = sum_exp_logits - predicted_logits + loss[ignore_mask] = 0.0 + + vocab_size = exp_logits.size(-1) + if label_smoothing > 0: + r""" + We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. + = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) + = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i + = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K + From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py + """ + assert 1.0 > label_smoothing > 0.0 + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + + # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. + log_probs = torch.log(exp_logits) + mean_log_probs = log_probs.mean(dim=-1) + loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs + + ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, ignore_mask) + + return loss + + @staticmethod + @internlm_accelerator.amp.custom_bwd + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d, ignore_mask = ctx.saved_tensors + label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size + + # All the inputs have softmax as thier gradient. + grad_input = softmax # s_{k} + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + if label_smoothing > 0: + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update + average_grad = 1 / vocab_size + grad_2d[arange_1d, :] -= smoothing * average_grad + else: + grad_2d[arange_1d, masked_target_1d] -= softmax_update + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + grad_input[ignore_mask] = 0.0 # set ignore token loss as 0. + + return grad_input, None, None, None + + +class CrossEntropyApexVocabParallel(nn.Module): + """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py + Supports vocab parallel loss calculation, but does not support inplace backward. + """ + + def __init__( + self, ignore_index=-100, reduction="mean", label_smoothing=0.0, process_group=None, inplace_backward=False + ): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + assert inplace_backward is False, "does not support inplace backward" + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.process_group = process_group + + def forward(self, vocab_parallel_logits, target): + # assert vocab_parallel_logits.is_cuda and vocab_parallel_logits.is_cuda + + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, self.label_smoothing, self.process_group) + if self.reduction == "mean": + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py b/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py new file mode 100644 index 00000000..2072944f --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py @@ -0,0 +1,121 @@ +import torch +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc + + +# Adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/ \ +# sequence_parallel/cross_entropy.py +class _VocabSequenceParallelCrossEntropy(torch.autograd.Function): + """ + Cross Entropy module for isp. + """ + + @staticmethod + def forward(ctx, vocab_seq_parallel_logits, target, reduction, label_smoothing=0.0): # pylint: disable=W0613 + sp_size = gpc.get_world_size(ParallelMode.TENSOR) + + # reshape + # vocab_seq_parallel_logits: [B * (S/P), V] -> [B, S/P, V] + # target: [B * S/P] -> [B, S/P] + bsz = gpc.config.data.micro_bsz if gpc.config.data.use_packed_dataset is False else 1 + vocab_seq_parallel_logits = vocab_seq_parallel_logits.view(bsz, -1, gpc.config.model.vocab_size) + target = target.view(bsz, -1) + + # transpose + # vocab_seq_parallel_logits: [B, S/P, V] -> [S/P, B, V] + # target: [B, S/P] -> [S/P, B] + # return: [S, B] + vocab_seq_parallel_logits = vocab_seq_parallel_logits.transpose(0, 1).contiguous() + target = target.transpose(0, 1).contiguous() + + ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_size + batch_size = vocab_seq_parallel_logits.size(1) + + # Need softmax for backward + softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1) + ctx.vocab_size = vocab_seq_parallel_logits.size(2) + loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction="none") + + loss_all = torch.empty( + ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device + ) + + torch.distributed.all_gather_into_tensor(loss_all, loss, group=gpc.get_group(ParallelMode.TENSOR)) + + # [s b] => [b, s] + loss_all = loss_all.transpose(0, 1).contiguous() + + ctx.save_for_backward(softmax, target) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + softmax, target = ctx.saved_tensors + + # transpose + grad_output = grad_output.transpose(0, 1).contiguous() + + step_seqlen = ctx.seqlen // gpc.get_world_size(ParallelMode.TENSOR) + sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1), :] + + grad_input = softmax + grad_2d = grad_input.view(-1, ctx.vocab_size) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + grad_2d[arange_1d, target.view(-1)] -= 1 + grad_input.mul_(grad_output_part.unsqueeze(dim=-1)) + + # transpose + grad_input = grad_input.transpose(0, 1).contiguous() + # reshape + grad_input = grad_input.view(-1, gpc.config.model.vocab_size) + + return grad_input, None, None + + +def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): + return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) + + +class VocabSequenceParallelCrossEntropyLoss(nn.Module): + """ + Cross Entropy module for isp. + """ + + def __init__( + self, + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0, + process_group=None, + ): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.process_group = process_group + + def loss_mean_func(self, output_tensor): + losses = output_tensor.float() + loss = torch.sum(losses.view(-1)) / losses.numel() + + # TODO: allreduce loss in dp group + + return loss + + def forward(self, _input, target): + assert _input.is_cuda and target.is_cuda + + _loss_list = vocab_sequence_parallel_cross_entropy(_input, target, self.label_smoothing) + + if self.reduction == "mean": + loss = self.loss_mean_func(_loss_list) + return loss + + return _loss_list.view(-1) diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py index 3ccbfb54..537a4077 100644 --- a/tests/test_infer/test_trainer_generate.py +++ b/tests/test_infer/test_trainer_generate.py @@ -10,7 +10,7 @@ from internlm.core.trainer import TrainState, Trainer # noqa: E402 from internlm.data import build_train_loader_with_data_type # noqa: E402 from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.model.losses import FlashGPTLMLoss # noqa: E402 +from internlm.model.losses import InternLoss # noqa: E402 from internlm.train import ( # noqa: E402 get_scheduler_hooks, initialize_model, @@ -25,7 +25,7 @@ def setup_generator(config, tokenizer): model = initialize_model() isp_communicator = initialize_parallel_communicator(model) - criterion = FlashGPTLMLoss() + criterion = InternLoss() # initialize the train data loader train_dl, _ = build_train_loader_with_data_type() diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index 48b97bfa..ba7f0118 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -15,7 +15,7 @@ from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.initialize.launch import args_sanity_check -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( initialize_model, @@ -175,7 +175,7 @@ def train_check_output(args): _ = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index e6890517..ddbb24a0 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -38,7 +38,7 @@ args_sanity_check, ) from internlm.model.losses import ( # noqa: E402 #pylint: disable=wrong-import-position - FlashGPTLMLoss, + InternLoss, ) from internlm.model.metrics import ( # noqa: E402 #pylint: disable=wrong-import-position AccPerplex, @@ -224,7 +224,7 @@ def train_model(args): _ = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) # initialize the train and validation data loader train_dl, dataset_types = build_train_loader_with_data_type() diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 2fd8ad4c..8b506d2d 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -13,7 +13,7 @@ from internlm.core.trainer import Trainer, TrainState from internlm.data import build_train_loader_with_data_type from internlm.initialize import initialize_distributed_env -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.train import ( get_scheduler_hooks, initialize_model, @@ -174,7 +174,7 @@ def train( isp_communicator = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) + criterion = InternLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) # initialize the train data loader train_dl, _ = build_train_loader_with_data_type() diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index f142e503..5f0782b4 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -8,7 +8,7 @@ from internlm.core.context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, @@ -58,7 +58,7 @@ def train_check(args): isp_communicator = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 0fd24926..990b334a 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -11,7 +11,7 @@ from internlm.core.context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, @@ -78,7 +78,7 @@ def train_check_norm_weight(args): isp_communicator = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 4fa096a5..13c01b1c 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -21,7 +21,7 @@ ) from internlm.eval.evaluation import switch_evaluation_mode from internlm.initialize.launch import args_sanity_check -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( initialize_model, @@ -275,7 +275,7 @@ def exam_loss(args): _ = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) # initialize the train and validation data loader train_dl, dataset_types = build_train_loader_with_data_type() diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 7926bae5..c7da6f85 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -27,7 +27,7 @@ ) from internlm.eval.evaluation import evaluate_on_val_dls # noqa: E402 from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.model.losses import FlashGPTLMLoss # noqa: E402 +from internlm.model.losses import InternLoss # noqa: E402 from internlm.model.metrics import AccPerplex, SchedulerMetricHook # noqa: E402 from internlm.monitor import ( # noqa: E402 initialize_monitor_manager, @@ -123,7 +123,7 @@ def main(args): config_lines = f.readlines() # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=label_smoothing) # initialize the train and validation data loader train_dl, dataset_types = build_train_loader_with_data_type() From e3f5001d0fc88b939bb73731feee624574b2f3d2 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Tue, 31 Dec 2024 14:22:53 +0800 Subject: [PATCH 12/12] feat(comm/attn_offload.py): support selective ckpt and cpu offload (#383) --- internlm/core/parallel/comm/__init__.py | 3 + internlm/core/parallel/comm/attn_offload.py | 127 ++++++++ internlm/core/parallel/comm/isp.py | 34 ++ internlm/core/trainer_builder.py | 4 + internlm/initialize/launch.py | 37 ++- internlm/model/ops/_flash_attn.py | 331 ++++++++++++++++++++ internlm/model/ops/attention.py | 23 +- internlm/train/pipeline.py | 1 + 8 files changed, 549 insertions(+), 11 deletions(-) create mode 100644 internlm/core/parallel/comm/attn_offload.py create mode 100644 internlm/model/ops/_flash_attn.py diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py index e69de29b..be170f28 100644 --- a/internlm/core/parallel/comm/__init__.py +++ b/internlm/core/parallel/comm/__init__.py @@ -0,0 +1,3 @@ +from .attn_offload import get_offload_manager, initialize_offload_manager + +__all__ = ["initialize_offload_manager", "get_offload_manager"] diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py new file mode 100644 index 00000000..da23f3ae --- /dev/null +++ b/internlm/core/parallel/comm/attn_offload.py @@ -0,0 +1,127 @@ +import torch + +from internlm.utils.common import get_current_device + +global_attn_offload = None + + +class AttnOffloadManager: + """ + A manager for attention output CPU offloading and GPU prefetch loading. + """ + + def __init__(self, enable_cpu_offload: bool = False) -> None: + # cpu offload overlapping + self.cpu_offload = enable_cpu_offload + # layer id mapping to flash attn output + self.fa_output_mapping = {} + self.fa_stream = torch.cuda.Stream() + self.d2h_final_event = torch.cuda.Event() + self.h2d_final_event = torch.cuda.Event() + # prepare for tensor buffer + self.tensor_id_to_tensor_bufs = {} + + def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id): + """Get tensor buffer for offloaded tensor.""" + layer_id = layer_id % 2 + if layer_id not in self.tensor_id_to_tensor_bufs: + self.tensor_id_to_tensor_bufs[layer_id] = {} + + if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]: + allocate_new_buf = True + else: + tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id] + allocate_new_buf = tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype + + if allocate_new_buf: + # supposed to only execute once + buffer = torch.empty( + tensor.size(), + dtype=tensor.dtype, + layout=tensor.layout, + device=tensor.device, + ) + + self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer + + return self.tensor_id_to_tensor_bufs[layer_id][tensor_id] + + def insert_fa_output_with_layer(self, layer_idx, output): + assert layer_idx not in self.fa_output_mapping + if self.cpu_offload is False: + self.fa_output_mapping[layer_idx] = output + return + + tensors = [] + for tensor_id, tensor in enumerate(output): + if tensor is None: + tensors.append(None) + continue + tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id) + tensor_buf.copy_(tensor) + tensors.append(tensor_buf) + self.fa_output_mapping[layer_idx] = tensors + + def get_fa_output_with_layer(self, layer_idx): + assert layer_idx in self.fa_output_mapping + return self.fa_output_mapping.pop(layer_idx) + + def offload_fa_output_with_layer(self, layer_idx): + assert layer_idx in self.fa_output_mapping + + self.fa_stream.wait_stream(torch.cuda.current_stream()) + self.fa_stream.wait_event(self.d2h_final_event) + + with torch.cuda.stream(self.fa_stream): + _gpu_tensors = self.fa_output_mapping.pop(layer_idx) + _cpu_tensors = [] + for _tensor in _gpu_tensors: + if _tensor is None: + _cpu_tensors.append(_tensor) + continue + + _cpu_backup = torch.empty( + _tensor.size(), + dtype=_tensor.dtype, + layout=_tensor.layout, + device="cpu", + pin_memory=True, + ) + _cpu_backup.copy_(_tensor, non_blocking=True) + _cpu_tensors.append(_cpu_backup) + + # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False)) + + self.fa_output_mapping[layer_idx] = _cpu_tensors + + self.fa_stream.record_event(self.d2h_final_event) + + def preload_fa_output_with_layer(self, layer_idx): + assert layer_idx in self.fa_output_mapping + + self.fa_stream.wait_stream(torch.cuda.current_stream()) + self.fa_stream.wait_event(self.h2d_final_event) + + # Important: get device before with stream, in stream get device is error + _device = get_current_device() + with torch.cuda.stream(self.fa_stream): + _cpu_tensors = self.fa_output_mapping.pop(layer_idx) + self.fa_output_mapping[layer_idx] = [ + _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor + for _tensor in _cpu_tensors + ] + + self.fa_stream.record_event(self.h2d_final_event) + + +def initialize_offload_manager(enable_cpu_offload: bool = False): + global global_attn_offload + if global_attn_offload is None: + global_attn_offload = AttnOffloadManager(enable_cpu_offload) + + return global_attn_offload + + +def get_offload_manager(): + assert global_attn_offload is not None + return global_attn_offload diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 7e722c2f..24677c09 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -37,6 +37,8 @@ params_dispatch_with_condition, ) +from .attn_offload import get_offload_manager + # not really useful, only for code hint. class WPCommunicator(ABC): @@ -306,6 +308,7 @@ def __init__( overlap: bool = False, process_group: dist.ProcessGroup = None, is_moe: bool = False, + selective_ckpt_offload: bool = False, ) -> None: self.process_group = process_group self.overlap = overlap @@ -316,6 +319,14 @@ def __init__( self._forward_prefetch_prerequisites = [] self._forward_overlap_per = self._get_forward_overlap_granularity() self._launch_before_module = self._get_launch_before_module() + # As an optimization, do not release weight after forward for the last + # transformer block since wp would prefetch it immediately + self.layers_wp_not_release = [] # [gpc.config.isp_num_layers - 1] + self.layers_fa_not_release = [ + gpc.config.isp_num_layers - 1, + int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1, + ] + self.sc_offload = selective_ckpt_offload # real overlap state for each chunk. self._overlap_states: Dict[int, ISPOverlapState] = {} @@ -411,6 +422,7 @@ def is_allgather_launch_module(name, module): self._overlap_states[cid].index_to_isp_modules[idx].append(child) setattr(child, "isp_name", name) + setattr(child, "isp_layer_idx", idx) full_name = f"{cid}.{idx}.{name}" setattr( @@ -506,6 +518,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args) if block_index + 1 < self._num_blocks: self._all_gather_block_weight(block_index + 1) + # register offload and prefetch hook for selective ckpt with wo linear + if self.sc_offload is True: + # move current layer's attn output from GPU to CPU asynchronizely + if ( + self.is_forward is True + and gpc.config.selective_checkpoint + and block_index not in self.layers_fa_not_release + and block_index < self._ckpt_block_num + ): + get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index) + + # load previous layer's attn output from CPU to GPU asynchronizely + if ( + self.is_forward is False + and gpc.config.selective_checkpoint + and (0 <= (block_index - 1) < self._ckpt_block_num) + ): + get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1) + def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 if module not in self._weight_global_handle: self._all_gather_module_weight(module) @@ -539,6 +570,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis self._all_gather_module_weight(next_module) def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 + if int(module.isp_layer_idx) in self.layers_wp_not_release: + # print(f"the layer {module.isp_layer_idx} after forward not clear weight") + return if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False): self._clear_handle(module) self._clear_weight(module) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 2b82bc1f..71c30d00 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -11,6 +11,7 @@ from internlm.checkpoint.checkpoint_manager import CheckpointManager from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode +from internlm.core.parallel.comm import initialize_offload_manager from internlm.core.trainer import Trainer from internlm.data.streaming.utils import streaming_simple_resume from internlm.data.train_state import get_train_state @@ -118,6 +119,9 @@ def __init__( # initialize isp communicator isp_communicator = initialize_parallel_communicator(model) + # initialize cpu offload manager for selective checkpoint + initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) + # initialize train state train_state = get_train_state(train_dl) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index c8b16516..b9e8e41b 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -66,6 +66,8 @@ def get_default_parser(): def args_sanity_check(): assert gpc.config is not None, "config is not load!" + gpc.is_forward = True + if "JOB_NAME" not in gpc.config: gpc.config._add_item("JOB_NAME", "AnonymousJob") @@ -73,6 +75,13 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) + if gpc.config.model_type == "InternLM3_M": + # TODO: need check for isp overlap + num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers + else: + num_layers = gpc.config.model.num_layers + gpc.config.isp_num_layers = num_layers + if "use_apex_adam" not in gpc.config: gpc.config._add_item("use_apex_adam", False) @@ -388,17 +397,18 @@ def args_sanity_check(): gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name - assert ( - gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0 - ), "VOCAB_SIZE must be integer multiple of tensor parallel size" if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp" assert ( torch.__version__ >= "2.1.0" ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}" - assert ( - gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0 - ), "VOCAB_SIZE must be integer multiple of wp size" + + assert ( + gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0 + ), "model.vocab_size must be integer multiple of weight parallel size" + assert ( + gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0 + ), "model.vocab_size must be integer multiple of tensor parallel size" assert gpc.config.parallel["tensor"].get("mode", None) in [ TensorParallelMode.mtp.name, @@ -524,7 +534,20 @@ def args_sanity_check(): gpc.config.loss._add_item("moe_loss_coeff", 1.0) if "selective_checkpoint" not in gpc.config: - gpc.config._add_item("selective_checkpoint", False) + gpc.config.selective_checkpoint = False + if "selective_checkpoint_offload" not in gpc.config: + gpc.config.selective_checkpoint_offload = False + if gpc.config.selective_checkpoint is True: + assert ( + gpc.config.parallel["tensor"]["mode"] == "isp" + ), "When using selective_checkpoint, tensor parallel mode must be isp" + if gpc.config.selective_checkpoint_offload is True: + assert ( + gpc.config.selective_checkpoint is True + ), "When using selective_checkpoint_offload, selective_checkpoint must be True" + assert ( + gpc.config.parallel.weight.launch_allgather_before == "wo" + ), "When using selective_checkpoint_offload, wp launch allgather communication should be set before 'wo' module" # moe not support overlap and zero1.5 for now if gpc.config.model.get("num_experts", 1) > 1: diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py new file mode 100644 index 00000000..87aac2eb --- /dev/null +++ b/internlm/model/ops/_flash_attn.py @@ -0,0 +1,331 @@ +# Copyright (c) InternLM. All rights reserved. +import torch + +from internlm.accelerator import get_accelerator +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import get_offload_manager + +try: + import flash_attn + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward, + _flash_attn_varlen_forward, + ) + + gpu_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + gpu_flash_attn_impl = False + +internlm_accelerator = get_accelerator() +device_backend = internlm_accelerator.get_accelerator_backend() + + +class FlashAttnVarlenKVPackedFunc_V263(torch.autograd.Function): + """ + Varlen KVPacked Func from Flash Attn v2.6.3. + """ + + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + layer_idx, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k, v = kv[:, 0], kv[:, 1] + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + + if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) + else: + ( + out, + q, + k, + v, + out_padded, + softmax_lse, + S_dmask, + rng_state, + ) = _flash_attn_varlen_forward( # pylint: disable=E1123 + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + get_offload_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) + ) + + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): # pylint: disable=W0613 + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( # pylint: disable=E1121,E1124 + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenKVPackedFunc_V221(torch.autograd.Function): + """ + Varlen KVPacked Func from Flash Attn v2.2.1. + """ + + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + layer_idx, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k, v = kv[:, 0], kv[:, 1] + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + + if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) + else: + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=return_softmax and dropout_p > 0, + ) + + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + get_offload_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): # pylint: disable=W0613 + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + layer_idx=0, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + + assert gpu_flash_attn_impl is True and flash_attn.__version__ in [ + "2.2.1", + "2.6.3", + ], "flash-attn should be installed and version must be v2.2.1 or v2.6.3" + + if flash_attn.__version__ == "2.2.1": + return FlashAttnVarlenKVPackedFunc_V221.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_attn_probs, + layer_idx, + ) + + return FlashAttnVarlenKVPackedFunc_V263.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + layer_idx, + ) diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 604ea77a..3aec51f5 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -93,13 +93,14 @@ from flash_attn.flash_attn_interface import ( flash_attn_varlen_func as _flash_varlen_qkvsplited_func, ) - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, - ) from flash_attn.flash_attn_interface import ( flash_attn_varlen_qkvpacked_func as _flash_varlen_qkvpacked_func, ) + from ._flash_attn import ( + flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, + ) + gpu_flash_attn_impl = True except (ModuleNotFoundError, ImportError): gpu_flash_attn_impl = False @@ -187,6 +188,7 @@ def _flash_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, ): # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) @@ -204,6 +206,7 @@ def _flash_varlen_kvpacked_attn( dropout_p, softmax_scale, causal, + layer_idx=layer_idx, ) return output.unsqueeze(dim=0) @@ -521,6 +524,7 @@ def _npu_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): # TODO: support npu native varlen flash attention k, v = kv.unbind(dim=2) @@ -579,6 +583,7 @@ def _deeplink_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) @@ -1012,7 +1017,17 @@ def _q_kv_with_cu_seqlens( extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( - q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout, + softmax_scale, + causal, + *extra_args, + layer_idx=self.layer_idx, ) @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 79e9caf4..ca11e689 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -363,6 +363,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.weight.overlap, gpc.get_group(ParallelMode.WEIGHT), is_moe=False, + selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), ) # register communicator for isp column parallel linear. ColumnParallelLinear.register_cls_communicator(isp_communicator)