feat: update v5.0
byshiue committed Apr 15, 2022
1 parent ba6960e commit a44c381
Showing 567 changed files with 300,741 additions and 16,230 deletions.
8 changes: 7 additions & 1 deletion .gitmodules
@@ -1,7 +1,13 @@
[submodule "3rdparty/Megatron-LM"]
path = 3rdparty/Megatron-LM
url = https://github.com/NVIDIA/Megatron-LM.git
branch = v2.4
branch = v2.6
[submodule "examples/tensorflow/bert/tensorflow_bert/bert"]
path = examples/tensorflow/bert/tensorflow_bert/bert
url = https://github.com/google-research/bert.git
[submodule "examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer"]
path = examples/pytorch/swin/Swin-Transformer-Quantization/SwinTransformer
url = https://github.com/microsoft/Swin-Transformer
[submodule "examples/pytorch/vit/ViT-quantization/ViT-pytorch"]
path = examples/pytorch/vit/ViT-quantization/ViT-pytorch
url = https://github.com/jeonsworld/ViT-pytorch
2 changes: 1 addition & 1 deletion 3rdparty/CMakeLists.txt
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion 3rdparty/trt_fused_multihead_attention/CMakeLists.txt
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
6 changes: 6 additions & 0 deletions 3rdparty/trt_fused_multihead_attention/common.cuh
@@ -18,7 +18,13 @@
#define COMMON_CUH

#include "cublas_v2.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

#define HDI inline __host__ __device__

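As a side note on the guard above (a sketch, not part of the commit): CUDART_VERSION encodes the runtime version as major * 1000 + minor * 10, so the 11050 threshold corresponds to CUDA 11.5; toolkits at or above that version take the bundled <cub/cub.cuh>, while older ones fall back to the vendored copy under 3rdparty/cub. A minimal check of which branch a given toolkit selects, assuming only that the CUDA headers are on the include path (e.g. compiled with nvcc):

// Sketch, not part of the commit: reports which CUB header the guard above
// would select for the local toolkit.
#include <cuda_runtime_api.h>  // defines CUDART_VERSION as major*1000 + minor*10
#include <cstdio>

int main() {
    std::printf("CUDA runtime version: %d.%d\n",
                CUDART_VERSION / 1000, (CUDART_VERSION % 1000) / 10);
#if CUDART_VERSION >= 11050
    std::printf("guard selects the toolkit-bundled <cub/cub.cuh>\n");
#else
    std::printf("guard selects the vendored 3rdparty/cub/cub.cuh\n");
#endif
    return 0;
}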

Large diffs in several additional files are not rendered by default.

161 changes: 130 additions & 31 deletions 3rdparty/trt_fused_multihead_attention/fused_multihead_attention_v2.h

Large diffs are not rendered by default.

188 changes: 188 additions & 0 deletions 3rdparty/trt_fused_multihead_attention/qkvToContext.cu
@@ -147,6 +147,56 @@ public:
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
}

void setup(const int S, const int B, const int window_num)
{
// TODO these implementation details might be better centralized into the XMMA code, since they are needed in
// several places (also outside of this plugin)
size_t warps_m = 2, warps_n = 2, warps_k = 1;
if (S == 64 || S == 128) {
warps_m = 2;
warps_n = 2;
}
else if (S == 256) {
warps_m = 1;
warps_n = 4;
}
else if (S == 384) {
warps_m = 1;
warps_n = 8;
}
else {
assert(false && "Unsupporte seqlen");
}
// The number of threads per CTA.
threads_per_cta = warps_m * warps_n * warps_k * 32;
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
// The number of xmmas in the N dimension.
xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);

const float scale_bmm1 = interface->mRsqrtHeadSize;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;

Data_type scale_type = DATA_TYPE_FP16;
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
set_alpha(params.scale_softmax, scale_softmax, scale_type);
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);

params.b = B;
params.h = interface->mNumHeads;
params.s = S;
params.d = interface->mHeadSize;
params.window_num = window_num;

// mLdQKV = 3 * B * mNumHeads * mHeadSize;
// mLdOut = B * mNumHeads * mHeadSize;

params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
params.packed_mask_stride_in_bytes = S * sizeof(half);
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
}
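
As a side note (a sketch, not part of the commit), the tile bookkeeping in the setup() overload above can be checked by hand: for the S == 256 branch, warps_m = 1, warps_n = 4 and warps_k = 1 give 1 * 4 * 1 * 32 = 128 threads per CTA, ceil(256 / 16) = 16 XMMAs along M and ceil(256 / 64) = 4 XMMAs along N, assuming the 16-element tile granularity implied by the divisors:

// Sketch, not part of the commit: reproduces the CTA/tile arithmetic from
// setup() for the S == 256 branch.
#include <cassert>
#include <cstddef>

int main() {
    const int S = 256;                                   // sequence length
    const std::size_t warps_m = 1, warps_n = 4, warps_k = 1;

    const std::size_t threads_per_cta = warps_m * warps_n * warps_k * 32;  // 128
    const std::size_t xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);   // ceil(256/16) = 16
    const std::size_t xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);   // ceil(256/64) = 4

    assert(threads_per_cta == 128 && xmmas_m == 16 && xmmas_n == 4);
    return 0;
}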

void run(const void* qkvPtr, const void* maskPtr, const void* cuSeqlenPtr, void* output, void* workspace, cudaStream_t stream)
{
params.qkv_ptr = const_cast<void*>(qkvPtr);
@@ -160,6 +210,29 @@ public:
check_cuda_error(cudaPeekAtLastError());
}

void run(const void* qkvPtr,
const void* maskPtr,
const void* relative_position_bias,
const int actual_seqlen,
void* output,
void* workspace,
cudaStream_t stream)
{
params.qkv_ptr = const_cast<void*>(qkvPtr);

params.packed_mask_ptr = const_cast<void*>(maskPtr);

params.packed_relative_position_bias_ptr = const_cast<void*>(relative_position_bias);

params.o_ptr = output;

params.actual_seqlen = actual_seqlen;

params.cu_seqlens = nullptr;
xmmaKernel->run(params, stream);
check_cuda_error(cudaPeekAtLastError());
}

bool isValid(int s) const
{
return xmmaKernel->isValid(s);
@@ -218,6 +291,12 @@ void FusedMHARunnerFP16v2::setup(const int S, const int B)
pimpl->setup(S, B);
}

void FusedMHARunnerFP16v2::setup(const int S, const int B, const int window_num)
{
MHARunner::setup(S, B, window_num);
pimpl->setup(S, B, window_num);
}

size_t FusedMHARunnerFP16v2::getWorkspaceSize() const
{
return 0;
@@ -233,6 +312,17 @@ void FusedMHARunnerFP16v2::run(const void* input, const void* mask, const void*
pimpl->run(input, mask, seqlen, output, workspace, stream);
}

void FusedMHARunnerFP16v2::run(const void* input,
const void* mask,
const void* relatice_position_bias,
const int actual_seqlen,
void* workspace,
void* output,
cudaStream_t stream)
{
pimpl->run(input, mask, relatice_position_bias, actual_seqlen, output, workspace, stream);
}

bool FusedMHARunnerFP16v2::isValid(int s) const
{
return pimpl->isValid(s);
@@ -273,6 +363,87 @@ public:
return interface->mB * xmmas_m * threads_per_cta * sizeof(uint32_t);
}

void setup(const int S, const int B, const int window_num)
{
size_t warps_m, warps_n, warps_k = 1;
if (S == 64) {
warps_m = 2;
warps_n = 2;
}
else if (S == 128) {
warps_m = 2;
warps_n = 2;
}
else if (S == 256) {
warps_m = 1;
warps_n = 4;
}
else if (S == 384) {
warps_m = 1;
warps_n = 8;
}
else {
assert(false && "Unsupported seqlen.");
}
// The number of threads per CTA.
threads_per_cta = warps_m * warps_n * warps_k * 32;
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
// The number of xmmas in the N dimension.
xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);

params.b = B;
params.h = interface->mNumHeads;
params.s = S;
params.d = interface->mHeadSize;
params.window_num = window_num;
params.use_int8_scale_max = true;
params.packed_mask_stride_in_bytes = S * sizeof(half);
params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(int8_t);
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(int8_t);

float scaleQkv = interface->mScaleQkv;
float scaleCtx = interface->mScaleCtx;

// float scaleBmm1 = scaleQkv * scaleQkv;// * (1.f / sqrtf(interface->mHeadSize));
// float scaleBmm2 = interface->mDqProbs * scaleQkv / scaleCtx;
// float scaleSoftmax = 1.f / interface->mDqProbs;
//TODO: unify 3 scales
float scaleBmm1 = scaleQkv * (1.f / sqrtf(interface->mHeadSize));
float scaleBmm2 = scaleCtx;
float scaleSoftmax = interface->mDqProbs;

params.scale_bmm1 = reinterpret_cast<const uint32_t&>(scaleBmm1);
params.scale_bmm2 = reinterpret_cast<const uint32_t&>(scaleBmm2);
params.scale_softmax = reinterpret_cast<const uint32_t&>(scaleSoftmax);

params.enable_i2f_trick =
-double(1 << 22) * double(scaleBmm2) <= -128.f && double(1 << 22) * double(scaleBmm2) >= 127.f;
}
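
As a side note (a sketch, not part of the commit): the reinterpret_casts above store the raw IEEE-754 bit pattern of each float scale into a uint32_t parameter field. The same bit copy can be written without reference type punning:

// Sketch, not part of the commit: copies a float's IEEE-754 bit pattern into
// a uint32_t, equivalent in effect to the reinterpret_cast used above.
#include <cstdint>
#include <cstring>

static inline uint32_t float_as_bits(float value) {
    uint32_t bits = 0;
    static_assert(sizeof(bits) == sizeof(value), "float must be 32-bit");
    std::memcpy(&bits, &value, sizeof(bits));  // well-defined bit copy
    return bits;
}

// Hypothetical usage: params.scale_bmm1 = float_as_bits(scaleBmm1);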

void run(const void* qkvPtr,
const void* maskPtr,
const void* relativePositionBiasPtr,
int actual_seqlen,
void* output,
void* workspace,
cudaStream_t stream)
{
params.qkv_ptr = const_cast<void*>(qkvPtr);

params.packed_mask_ptr = const_cast<void*>(maskPtr);

params.packed_relative_position_bias_ptr = const_cast<void*>(relativePositionBiasPtr);

params.actual_seqlen = actual_seqlen;

params.o_ptr = output;

params.cu_seqlens = nullptr;

xmmaKernel->run(params, stream);
}

void setup(const int S, const int B)
{
size_t warps_m, warps_n, warps_k = 1;
@@ -430,6 +601,11 @@ void FusedMHARunnerInt8v2::setScaleList(const float scaleQkv, const float dqProb
mScaleCtx = scaleCtx;
}

void FusedMHARunnerInt8v2::setup(const int S, const int B, const int window_num)
{
pimpl->setup(S, B, window_num);
}

void FusedMHARunnerInt8v2::setup(const int S, const int B)
{
pimpl->setup(S, B);
@@ -445,6 +621,18 @@ void FusedMHARunnerInt8v2::run(const void* input, const void* mask, void* worksp
assert(false && "Not implemented");
}

void FusedMHARunnerInt8v2::run(const void* input,
const void* mask,
const void* relatice_position_bias,
int actual_seqlen,
void* workspace,
void* output,
cudaStream_t stream)
{
pimpl->run(input, mask, relatice_position_bias, actual_seqlen, output, workspace, stream);
}


void FusedMHARunnerInt8v2::run(const void* input, const void* mask, const void* seqlen, void* workspace, void* output, cudaStream_t stream)
{
pimpl->run(input, mask, seqlen, output, workspace, stream);
10 changes: 10 additions & 0 deletions 3rdparty/trt_fused_multihead_attention/qkvToContext.h
@@ -68,8 +68,14 @@ class MHARunner
mNumMats = B * mNumHeads;
}

virtual void setup(const int S, const int B, const int window_num)
{
setup(S, B);
}

virtual void run(const void* input, const void* mask, void* workspace, void* output, cudaStream_t stream) = 0;
virtual void run(const void* input, const void* mask, const void* seqlen, void* workspace, void* output, cudaStream_t stream) = 0;
virtual void run(const void* input, const void* mask, const void* relatice_position_bias, const int actual_seqlen, void* workspace, void* output, cudaStream_t stream) = 0;

virtual void setScaleList(const float scaleQkv, const float dqProbs, const float scaleCtx) = 0;

@@ -102,9 +108,11 @@ class FusedMHARunnerFP16v2 : public MHARunner
~FusedMHARunnerFP16v2() = default; // for pimpl

virtual void setup(const int S, const int B) override;
virtual void setup(const int S, const int B, const int window_num) override;

void run(const void* input, const void* mask, void* workspace, void* output, cudaStream_t stream);
void run(const void* input, const void* mask, const void* seqlen, void* workspace, void* output, cudaStream_t stream) override;
void run(const void* input, const void* mask, const void* relatice_position_bias, const int actual_seqlen, void* workspace, void* output, cudaStream_t stream) override;

void setScaleList(const float scaleQkv, const float dqProbs, const float scaleCtx) override;

@@ -129,9 +137,11 @@ class FusedMHARunnerInt8v2 : public MHARunner
void setScaleList(const float scaleQkv, const float dqProbs, const float scaleCtx);

virtual void setup(const int S, const int B) override;
virtual void setup(const int S, const int B, const int window_num) override;

void run(const void* input, const void* mask, void* workspace, void* output, cudaStream_t stream);
void run(const void* input, const void* mask, const void* seqlen, void* workspace, void* output, cudaStream_t stream) override;
void run(const void* input, const void* mask, const void* relatice_position_bias, const int actual_seqlen, void* workspace, void* output, cudaStream_t stream) override;

size_t getWorkspaceSize() const override;

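As a closing side note (a sketch, not part of the commit): a minimal call sequence for the new window-attention overloads declared above, assuming a FusedMHARunnerFP16v2 instance, valid device buffers and a stream already exist; the names below are placeholders, and only methods visible in this diff are called.

// Sketch, not part of the commit: exercises the new three-argument setup()
// and the relative-position-bias run() overload added for windowed attention.
#include "qkvToContext.h"  // path assumed; provides FusedMHARunnerFP16v2 and cudaStream_t

void run_windowed_mha(FusedMHARunnerFP16v2& runner,
                      const void* d_qkv,           // packed QKV activations (FP16, device)
                      const void* d_mask,          // packed attention mask (device)
                      const void* d_rel_pos_bias,  // packed relative position bias (device)
                      void* d_workspace,
                      void* d_output,
                      int seq_len,                 // S: one of 64, 128, 256, 384
                      int batch,                   // B
                      int window_num,
                      int actual_seqlen,
                      cudaStream_t stream)
{
    if (!runner.isValid(seq_len)) {
        return;  // no fused kernel for this sequence length
    }
    runner.setup(seq_len, batch, window_num);       // new overload in this commit
    runner.run(d_qkv, d_mask, d_rel_pos_bias, actual_seqlen,
               d_workspace, d_output, stream);      // new overload in this commit
}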