LayerNormalization broadcast (limited support for axis=2) #23297

Merged · 11 commits · Jan 11, 2025

Changes from 8 commits
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
       (double)epsilon_,                                                               // epsilon
       reinterpret_cast<const CudaT*>(gamma->Data<T>()),                               // gamma
       (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,  // beta
+      0,                                                                              // no broadcast for gamma/beta
       reinterpret_cast<const CudaT*>(skip->Data<T>()),                                // skip or residual to add
       (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,  // bias to add
       sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
76 changes: 76 additions & 0 deletions onnxruntime/core/providers/cpu/nn/layer_norm_helper.h
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/tensor_shape.h"
#include "core/common/status.h"

namespace onnxruntime {

constexpr const char* kLayerNormInputShapeMismatchError =
    "Size of scale and bias (if provided) must match X.shape[axis:], "
    "or scale and bias (with same shape) can be broadcasted to X when axis is 2.";

constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";

class LayerNormHelper {
 public:
  static Status CheckBroadcast(const TensorShape& x_shape,
                               const TensorShape& scale_shape,
                               const TensorShape& bias_shape,
                               bool has_bias,
                               int64_t axis,
                               int64_t& broadcast_param) {
    broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
    if (broadcast_param == 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             kLayerNormInputShapeMismatchError,
                             " X.shape=", x_shape,
                             " scale.shape=", scale_shape,
                             " bias.shape=", bias_shape,
                             " and axis=", axis);
    }

    return Status::OK();
  }

 private:
  static int64_t GetBroadcastParam(const TensorShape& x_shape,
                                   const TensorShape& scale_shape,
                                   const TensorShape* bias_shape,
                                   int64_t axis) {
    // X shape is (B, S, ...)
    if (axis == 2 &&
        x_shape.NumDimensions() >= 3 &&
        x_shape.NumDimensions() == scale_shape.NumDimensions() &&
        (bias_shape == nullptr || *bias_shape == scale_shape)) {
      for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
        if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) {
          return 0;
        }
      }

      if (x_shape.GetDims()[0] == scale_shape.GetDims()[0]) {
        // scale and bias shape is (B, S, ...).
        if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
          return 1;
        }

        // scale and bias shape is (B, 1, ...), returns S
        if (scale_shape.GetDims()[1] == 1) {
          return x_shape.GetDims()[1];
        }
      } else if (scale_shape.GetDims()[0] == 1) {
        // scale and bias shape is (1, S, ...), returns -S
        if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
          return -(x_shape.GetDims()[1]);
        }
      }
    }

    return 0;
  }
};

}  // namespace onnxruntime
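The helper's contract is easiest to see with concrete shapes. Below is a minimal standalone sketch (illustrative only; BroadcastParamSketch is not an ONNX Runtime API) that mirrors GetBroadcastParam for plain dimension vectors, with the supported axis == 2 cases spelled out:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the encoding of LayerNormHelper::GetBroadcastParam for axis == 2
// (illustrative only; plain vectors instead of TensorShape):
//    1 -> scale/bias are (B, S, ...): every row has its own slice
//    S -> scale/bias are (B, 1, ...): rows in the same batch share a slice
//   -S -> scale/bias are (1, S, ...): rows at the same position share a slice
//    0 -> rejected (callers treat 0 from CheckBroadcast as an error)
int64_t BroadcastParamSketch(const std::vector<int64_t>& x,
                             const std::vector<int64_t>& scale) {
  if (x.size() < 3 || x.size() != scale.size()) return 0;
  for (size_t i = 2; i < x.size(); ++i) {
    if (x[i] != scale[i]) return 0;  // dims from axis onward must match exactly
  }
  if (x[0] == scale[0]) {
    if (x[1] == scale[1]) return 1;  // (B, S, ...)
    if (scale[1] == 1) return x[1];  // (B, 1, ...) -> S
  } else if (scale[0] == 1 && x[1] == scale[1]) {
    return -x[1];                    // (1, S, ...) -> -S
  }
  return 0;
}

int main() {
  // X.shape = (B, S, H) = (2, 3, 4), axis = 2:
  assert(BroadcastParamSketch({2, 3, 4}, {2, 3, 4}) == 1);
  assert(BroadcastParamSketch({2, 3, 4}, {2, 1, 4}) == 3);   // S
  assert(BroadcastParamSketch({2, 3, 4}, {1, 3, 4}) == -3);  // -S
  // A (1, 1, H) scale never reaches CheckBroadcast in practice: its element
  // count already equals norm_size, so kernels take the broadcast_param == 0 path.
  assert(BroadcastParamSketch({2, 3, 4}, {1, 1, 4}) == 0);
  return 0;
}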
78 changes: 52 additions & 26 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "layer_norm_impl.h"
+#include "layer_norm_helper.h"
 
 #include "core/common/safeint.h"
 #include "core/framework/tensor.h"
@@ -24,6 +25,7 @@
     const T* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
+    const int64_t broadcast_param,
     const float* scale_float_ptr,
     const float* bias_float_ptr,
     float epsilon,
@@ -55,13 +57,24 @@
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  for (int64_t h = 0; h < norm_size; h++) {
+  // When X shape is (B, S, ...), task_idx is in the range [0, B * S).
+  // We support the following scale and bias shapes:
+  //   When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
+  //   When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
+  //   When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
+  //   When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
+  // Here we compute the initial index for scale and bias data.
+  int64_t i = (broadcast_param == 0)
+                  ? 0
+                  : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
+
+  for (int64_t h = 0; h < norm_size; h++, i++) {
     if (simplified) {
-      p_output[h] = p_output[h] / mean_square * scale_data[h];
+      p_output[h] = p_output[h] / mean_square * scale_data[i];
     } else if (nullptr == bias_data) {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i];
     } else {
-      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
+      p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i];
     }
   }

Expand All @@ -82,6 +95,7 @@
const MLFloat16* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
const int64_t broadcast_param,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
@@ -120,13 +134,24 @@
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  for (size_t h = 0; h < num_elems; h++) {
+  // When X shape is (B, S, ...), task_idx is in the range [0, B * S).
+  // We support the following scale and bias shapes:
+  //   When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
+  //   When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
+  //   When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
+  //   When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
+  // Here we compute the initial index for scale and bias data.
+  int64_t i = (broadcast_param == 0)
+                  ? 0
+                  : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
+
+  for (size_t h = 0; h < num_elems; h++, i++) {
     if (simplified) {
-      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i];
     } else if (nullptr == bias_float_ptr) {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i];
     } else {
-      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
+      output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i];
     }
   }

@@ -161,9 +186,7 @@
       simplified_{simplified},
       contrib_op_{contrib_op},
       prepacked_scale_fp32_data_(nullptr),
-      prepacked_scale_fp32_size_(0),
-      prepacked_bias_fp32_data_(nullptr),
-      prepacked_bias_fp32_size_(0) {
+      prepacked_bias_fp32_data_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }
@@ -179,8 +202,8 @@
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();
 
   const TensorShape& x_shape = X->Shape();
-  size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
-  size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
+  const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_;
+  const TensorShape& bias_shape = bias ? bias->Shape() : prepacked_bias_fp32_shape_;
   Tensor* Y = p_ctx->Output(0, x_shape);
   T* Y_data = Y->MutableData<T>();

@@ -215,7 +238,7 @@
 
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
+  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
                                      inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }

@@ -234,10 +257,10 @@
 
   is_packed = false;
   if (input_idx == 1) {  // scale
-    prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_scale_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
   } else if (input_idx == 2) {  // bias
-    prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    prepacked_bias_fp32_shape_ = tensor.Shape();
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
   }

@@ -249,9 +272,9 @@
     const T* X_data,
     const TensorShape& x_shape,
     const T* scale_data,
-    size_t scale_size,
+    const TensorShape& scale_shape,
     const T* bias_data,
-    size_t bias_size,
+    const TensorShape& bias_shape,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
@@ -263,23 +286,26 @@
   int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
   int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
 
-  if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", norm_size,
-                           ". Size of scale and bias (if provided) must match this. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
+  int64_t scale_size = scale_shape.Size();
+  int64_t bias_size = bias_shape.Size();
+  int64_t broadcast_param = 0;
+
+  if (norm_size <= 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, norm_size);
+  } else if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
+    ORT_RETURN_IF_ERROR(LayerNormHelper::CheckBroadcast(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, broadcast_param));
   }
 
   IAllocatorUniquePtr<float> scale_fp32;
   IAllocatorUniquePtr<float> bias_fp32;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     if (prepacked_scale_fp32_data_ == nullptr) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(scale_size);
       scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
     }
     if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
-      const size_t num_elems = static_cast<size_t>(norm_size);
+      const size_t num_elems = static_cast<size_t>(bias_size);
       bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
     }
@@ -288,7 +314,7 @@
   concurrency::ThreadPool::TryBatchParallelFor(
       thread_pool, static_cast<int32_t>(norm_count),
       [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, broadcast_param,
                    prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
                    prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
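The start-offset formula in ComputeJob is easiest to verify with concrete numbers. A small self-contained sketch (illustrative, not PR code) for X of shape (B, S, H) = (2, 3, 4) with axis = 2, where norm_size = 4 and task_idx enumerates the B * S = 6 rows:

#include <cstdint>
#include <cstdio>

// Illustrative only: the start offset into scale/bias for each row, as
// computed at the top of ComputeJob. Here B = 2, S = 3, H = norm_size = 4.
int64_t StartIndex(int64_t task_idx, int64_t broadcast_param, int64_t norm_size) {
  return (broadcast_param == 0)
             ? 0
             : norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param)
                                                : (task_idx % (-broadcast_param)));
}

int main() {
  const int64_t norm_size = 4;                            // H
  for (int64_t task_idx = 0; task_idx < 6; ++task_idx) {  // B * S = 6 rows
    std::printf("row %lld: (B,S,H)->%lld  (B,1,H)->%lld  (1,S,H)->%lld\n",
                static_cast<long long>(task_idx),
                static_cast<long long>(StartIndex(task_idx, 1, norm_size)),    // scale (2,3,4)
                static_cast<long long>(StartIndex(task_idx, 3, norm_size)),    // scale (2,1,4), param = S
                static_cast<long long>(StartIndex(task_idx, -3, norm_size)));  // scale (1,3,4), param = -S
  }
  // Prints offsets 0,4,8,12,16,20 / 0,0,0,4,4,4 / 0,4,8,0,4,8 respectively:
  // per-row slices, per-batch slices, and per-position slices.
  return 0;
}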
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel {
       const T* X_data,
       const TensorShape& x_shape,
       const T* scale_data,
-      size_t scale_size,
+      const TensorShape& scale_shape,
       const T* bias_data,
-      size_t bias_size,
+      const TensorShape& bias_shape,
       T* Y_data,
       U* mean_data,
       U* inv_std_dev,
@@ -64,9 +64,9 @@ class LayerNormImpl : public OpKernel {
   const bool simplified_;
   const bool contrib_op_;
   IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
-  size_t prepacked_scale_fp32_size_;
+  TensorShape prepacked_scale_fp32_shape_;
   IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
-  size_t prepacked_bias_fp32_size_;
+  TensorShape prepacked_bias_fp32_shape_;
 };
 
 }  // namespace onnxruntime
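A likely reason the prepacked members now store a TensorShape instead of a flat size_t (inferred from the change, not stated in the PR): once broadcasting is allowed, the element count alone cannot distinguish a supported layout from an unsupported one. A tiny illustration (not ORT code):

#include <cstdint>
#include <vector>

int main() {
  // For X = (2, 3, 4) with axis = 2, both candidate scale shapes below hold
  // 8 elements, yet only the first is a supported broadcast; a stored size_t
  // of 8 could not tell them apart, while the full shape can.
  std::vector<int64_t> supported = {2, 1, 4};  // (B, 1, H) -> broadcast_param = S = 3
  std::vector<int64_t> rejected  = {1, 2, 4};  // second dim 2 != S = 3 -> CheckBroadcast fails
  (void)supported;
  (void)rejected;
  return 0;
}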
27 changes: 16 additions & 11 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
@@ -4,6 +4,7 @@
 #include "core/providers/shared_library/provider_api.h"
 #include "core/providers/cuda/nn/layer_norm.h"
 #include "core/providers/cuda/nn/layer_norm_impl.h"
+#include "core/providers/cpu/nn/layer_norm_helper.h"
 #include "core/providers/cuda/cuda_common.h"
 
 namespace onnxruntime {
@@ -44,19 +45,22 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
   auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());
 
   const TensorShape& x_shape = X->Shape();
-  const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
+  auto x_num_dims = x_shape.NumDimensions();
+  const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);
 
   int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
   int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));
 
-  const auto scale_size = scale->Shape().Size();
-  const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
-  if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Size of X.shape()[axis:] == ", n2,
-                           ". Size of scale and bias (if provided) must match this "
-                           "and the size must not be 1. Got scale size of ",
-                           scale_size, " and bias size of ", bias_size);
+  const TensorShape& scale_shape = scale->Shape();
+
+  const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape();
+
+  int64_t broadcast_param = 0;
+  if (n2 <= 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, n2);
+  } else if (scale_shape.Size() != n2 || (bias_data && bias_shape.Size() != n2)) {
+    // Check if scale and bias can be broadcasted to X (only limited cases are supported).
+    ORT_RETURN_IF_ERROR(LayerNormHelper::CheckBroadcast(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, broadcast_param));
   }
 
   // Outputs
@@ -65,7 +69,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
 
   // Mean and variance
   std::vector<int64_t> mean_inv_std_var_dim;
-  for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
+  for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
     if (i < axis) {
       mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
     } else {
@@ -94,7 +98,8 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
   }
 
   HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
-                                                      X_data, n1, n2, epsilon_, scale_data, bias_data);
+                                                      X_data, n1, n2, epsilon_, scale_data, bias_data,
+                                                      gsl::narrow_cast<int>(broadcast_param));
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
   return Status::OK();
 }
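The matching device-side change behind HostApplyLayerNorm's new argument is not part of this diff. Presumably the kernel derives each row's gamma/beta offset the same way the CPU path does; a hedged sketch of such a per-row offset helper (illustrative CUDA, not the actual onnxruntime kernel code):

// Illustrative only: mirrors the CPU ComputeJob start-index formula.
// row is the flattened (batch, sequence) row index in [0, n1); n2 is norm_size.
__device__ __forceinline__ int GammaBetaOffset(int row, int n2, int broadcast_param) {
  if (broadcast_param == 0) return 0;           // exact match or (1, 1, ...): shared slice
  return broadcast_param > 0
             ? n2 * (row / broadcast_param)     // (B, S, ...) param = 1, (B, 1, ...) param = S
             : n2 * (row % (-broadcast_param)); // (1, S, ...) param = -S
}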