Ck tile/layernorm: implement naive reduce, opt performance (#1784)
* add no-Welford (naive) reduce path

* enable raw output

* raw output for int8

* fix build

* fix smoke test error

* [ck_tile] layernorm: Welford fix verified; set int8 and bf16 small-N variants as defaults, keep the others opt-in via generate.py

* [ck_tile] layernorm: fix wrongly committed files and remove useless ones

* fix quant error at n=8192 & rename the norm_reduce class and files

---------

Co-authored-by: coderfeli <[email protected]>
Co-authored-by: carlushuang <[email protected]>
3 people authored Jan 3, 2025
1 parent 17e8efb commit 4bc6104
Showing 10 changed files with 253 additions and 170 deletions.
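In a nutshell, the commit adds a second, non-Welford ("naive") reduction path to the ck_tile layernorm pipelines: instead of Welford's online mean/variance update, it accumulates a plain sum and sum of squares and finalizes mean and variance once at the end, with the choice exposed through a new kWelford trait. A minimal host-side sketch of the two styles for contrast (illustration only with placeholder names, not the actual BlockNormReduce code):

// Welford: numerically stable online update of (count, mean, m2).
inline void welford_push(float x, int& count, float& mean, float& m2)
{
    count += 1;
    const float delta = x - mean;
    mean += delta / count;
    m2 += delta * (x - mean); // variance = m2 / count
}

// Naive reduce: accumulate raw moments, divide once at the end.
inline void naive_mean_var(const float* x, int n, float& mean, float& var)
{
    float sum = 0.f, sum_sq = 0.f;
    for(int i = 0; i < n; ++i)
    {
        sum += x[i];
        sum_sq += x[i] * x[i];
    }
    mean = sum / n;
    var  = sum_sq / n - mean * mean; // E[x^2] - (E[x])^2
}

The naive path does less per-element work but is more exposed to cancellation for large N; this matches the one-pass pipeline below, which divides the accumulated sums by row_size only when kWelford is false.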
160 changes: 93 additions & 67 deletions example/ck_tile/02_layernorm2d/generate.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion example/ck_tile/02_layernorm2d/script/smoke_test.sh
@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
-#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120
+$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
 done
 done
@@ -4,8 +4,8 @@
 #pragma once
 
 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
-#include "ck_tile/ops/welford/block/block_welford.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
 
 namespace ck_tile {
 
@@ -43,56 +43,59 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce()
     {
-        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                       typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape,
-                                       Problem::Traits::kFastFDiv>;
-
-        return BlockWelford<P_>{};
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
+        return BlockNormReduce<P_>{};
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync()
    {
-        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                       typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape,
-                                       Problem::Traits::kFastFDiv>;
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
 
-        return BlockWelfordSync<P_>{};
+        return BlockNormReduceSync<P_>{};
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                       typename Problem::ComputeDataType,
-                                       typename Problem::BlockShape,
-                                       Problem::Traits::kFastFDiv>;
+        using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                          typename Problem::ComputeDataType,
+                                          typename Problem::BlockShape,
+                                          Problem::Traits::kFastFDiv,
+                                          Problem::Traits::kWelford>;
 
-        return BlockWelfordCrossWarpSync<P_>{};
+        return BlockNormReduceCrossWarpSync<P_>{};
     }
 
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         if constexpr(Problem::kNeedCrossWarpSync)
         {
-            using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
-                                           typename Problem::ComputeDataType,
-                                           typename Problem::BlockShape,
-                                           Problem::Traits::kFastFDiv>;
+            using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
+                                              typename Problem::ComputeDataType,
+                                              typename Problem::BlockShape,
+                                              Problem::Traits::kFastFDiv,
+                                              Problem::Traits::kWelford>;
 
-            using block_welford = BlockWelford<P_>;
+            using block_welford = BlockNormReduce<P_>;
             using x_block_tile =
                 decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
                     MakeXBlockTileDistribution<Problem>()));
             using mean_var_block_tile =
                 decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
 
-            return GetBlockWelfordCrossWarpSync<Problem>()
+            return GetBlockNormReduceCrossWarpSync<Problem>()
                 .template GetSmemSize<mean_var_block_tile>();
         }
         else
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
 
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
         int cur_count = 0;
         int max_count =
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(row_size);
-        auto block_welford = Policy::template GetBlockWelford<Problem>();
-        auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
-        auto block_welford_cross_warp_sync =
-            Policy::template GetBlockWelfordCrossWarpSync<Problem>();
-
+        auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
+
+        using XTensorType = decltype(cast_tile<ComputeDataType>(x));
+        auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        clear_tile(mean);
+        clear_tile(var);
         // load gamma/beta (TODO: support no gamma/beta?)
         const auto gamma = load_tile(gamma_window);
         const auto beta = load_tile(beta_window);
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
             store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
         }
 
-        // compute welford each-thread->cross-lane->cross-warp
-        auto [mean, var] = block_welford(acc, cur_count, max_count);
-        block_welford_sync(mean, var, cur_count);
-        block_welford_cross_warp_sync(mean, var, cur_count, smem);
-        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
-
+        // compute reduce each-thread->cross-lane->cross-warp
+        block_norm_reduce(acc, mean, var, cur_count, max_count);
+        block_norm_reduce_sync(mean, var, cur_count);
+        block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
+        if(kWelford)
+        {
+            block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
+        }
+        else
+        {
+            sweep_tile(mean, [&](auto idx) {
+                mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
+                var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
+            });
+        }
         // compute inv-std
         auto inv_std = tile_elementwise_in(
             [&](const auto& v_) {
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
                 const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
 
                 auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
-
-            ln(idx) = ln_;
+                ln(idx) = ln_;
             });
 
         if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
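For reference, both reduction paths feed the same normalization that follows in this pipeline (epsilon is the usual small constant added before the square root for numerical safety; it is not visible in the truncated hunk above):

inv_std = 1 / sqrt(var + epsilon)
y = (x - mean) * inv_std * gamma + beta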
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
     static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
     static constexpr bool kPadN = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
+    static constexpr bool kWelford = Problem::Traits::kWelford;
     static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
 
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
                                void* smem,
                                Epilogue) const
     {
+        static_assert(kWelford == true, "2 pass only supports welford merge");
         auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
         auto gamma_window = make_tile_window(
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
         int max_count =
             (num_n_tile_iteration - 1) * count_per_iter +
             block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
-        auto block_welford = Policy::template GetBlockWelford<Problem>();
-        auto block_welford_sync = Policy::template GetBlockWelfordSync<Problem>();
-        auto block_welford_cross_warp_sync =
-            Policy::template GetBlockWelfordCrossWarpSync<Problem>();
+        auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
+        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
+        auto block_norm_reduce_cross_warp_sync =
+            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
 
         using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
-        auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
-        auto var = block_welford.template MakeMeanVarBlockTile<XTensorType>();
+        auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
+        auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
 
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
                     move_tile_window(y_residual_window, {0, Block_N});
                 }
             }
-            block_welford(acc, mean, var, cur_count, max_count);
+            block_norm_reduce(acc, mean, var, cur_count, max_count);
         }
 
-        block_welford_sync(mean, var, cur_count);
-        block_welford_cross_warp_sync(mean, var, cur_count, smem);
+        block_norm_reduce_sync(mean, var, cur_count);
+        block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
         block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
 
         // compute inv-std
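The static_assert above pins the two-pass pipeline to the Welford path: partial (count, mean, m2) statistics are accumulated tile by tile along N and must then be merged. A minimal sketch of the standard parallel Welford merge (Chan et al.) that such a combine relies on, not the actual ck_tile block merge code:

struct WelfordPart
{
    int count;  // number of elements folded in so far
    float mean; // running mean
    float m2;   // sum of squared deviations from the mean; variance = m2 / count
};

// Merge two partial statistics; assumes a.count + b.count > 0.
inline WelfordPart welford_merge(const WelfordPart& a, const WelfordPart& b)
{
    WelfordPart r;
    r.count = a.count + b.count;
    const float delta = b.mean - a.mean;
    r.mean = a.mean + delta * b.count / r.count;
    r.m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
    return r;
}

A naive-reduce merge would simply add the partial sums and sums of squares; that variant is not wired up for the two-pass case here, hence the assertion.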
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
 template <bool kPadN_,
           bool kSaveMeanInvStd_,
           bool kFastFDiv_,
+          bool kWelford_,
           bool kTwoPass_,
           Layernorm2dFusedAddEnum kFusedAdd_,
           Layernorm2dFusedQuantEnum kFusedQuant_>
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
     static constexpr bool kPadN = kPadN_;
     static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
     static constexpr bool kFastFDiv = kFastFDiv_;
+    static constexpr bool kWelford = kWelford_;
     static constexpr bool kTwoPass = kTwoPass_;
     static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
     static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
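A hypothetical instantiation showing where the new kWelford_ parameter sits in the trait list; the boolean values are arbitrary and the two enum arguments are value-initialized placeholders rather than members taken from this commit:

// Placeholder enum values for illustration only; real code would pass concrete members.
constexpr auto fused_add_v = Layernorm2dFusedAddEnum{};
constexpr auto fused_quant_v = Layernorm2dFusedQuantEnum{};

using MyTraits = Layernorm2dFwdTraits<true,   // kPadN_
                                      false,  // kSaveMeanInvStd_
                                      true,   // kFastFDiv_
                                      false,  // kWelford_ : false selects the naive reduce
                                      false,  // kTwoPass_ (two-pass requires kWelford_ = true)
                                      fused_add_v,
                                      fused_quant_v>;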
@@ -3,8 +3,8 @@
 
 #pragma once
 
-#include "ck_tile/ops/welford/block/block_welford.hpp"
-#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
-#include "ck_tile/ops/welford/thread/thread_welford.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
+#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
+#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
