From a69fe5e7ab6a9c4a9b839904261f02313fc503d4 Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee Date: Mon, 12 Feb 2024 09:24:11 +0900 Subject: [PATCH] [GPU] Fix rms not to be fused if the size of target dimension is > WGS (#22768) ### Details: - Fix rms not to be fused if the size of target dimension is > WGS ### Tickets: - 131958 --- .../intel_gpu/src/graph/impls/ocl/rms.cpp | 9 ++++ .../kernels/rms/rms_kernel_base.h | 1 + .../kernels/rms/rms_kernel_bfyx_opt.cpp | 47 +++++++++++++++---- .../kernels/rms/rms_kernel_bfyx_opt.h | 1 + .../src/plugin/transformations/rms_fusion.cpp | 16 ++++++- .../src/plugin/transformations/rms_fusion.hpp | 2 +- .../src/plugin/transformations_pipeline.cpp | 2 +- .../rms_norm_decomposition_test.cpp | 10 ++-- 8 files changed, 72 insertions(+), 16 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/rms.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/rms.cpp index 0d193ecb88bccc..d60d31a7376e66 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/rms.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/rms.cpp @@ -39,6 +39,7 @@ struct rms_impl : typed_primitive_impl_ocl { params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(1))); params.epsilon = primitive->epsilon; + params.ov_input_rank = static_cast(impl_param.get_input_layout().get_partial_shape().size()); return {params, optional_params}; } @@ -46,6 +47,14 @@ struct rms_impl : typed_primitive_impl_ocl { auto kernel_params = get_kernel_params(impl_param, true); (_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data); } + + static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) { + return impl_params; + } + + kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override { + return static_canonicalize_shapes(impl_params); + } }; namespace detail { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_base.h 
b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_base.h index 6adbd6aeb6657f..e00c9d36b8cb7d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_base.h @@ -13,6 +13,7 @@ namespace kernel_selector { struct rms_params : public base_params { rms_params() : base_params(KernelType::RMS) {} float epsilon = 0.0f; + int32_t ov_input_rank = -1; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp index ad49fd86370e0a..dc780bf9eb1b0d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.cpp @@ -31,7 +31,22 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, Dispatc if (params.has_dynamic_tensors()) { const auto& input = params.inputs[0]; DimensionAccessHelper dims(input); - const std::string data_size = toVectorMulString({dims.x(), dims.y(), dims.z()}); + std::string data_size; + switch (params.ov_input_rank) { + case 1 : + data_size = dims.b(); + break; + case 2 : + data_size = dims.f(); + break; + case 3 : + data_size = dims.y(); + break; + default: + data_size = dims.x(); + break; + } + const std::string lws_0 = "get_local_size(0)"; jit.AddConstants({ MakeJitConstant("DATA_SIZE", data_size), @@ -47,7 +62,7 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, Dispatc }); } jit.AddConstants({ - MakeJitConstant("VEC_SIZE", 8), + MakeJitConstant("VEC_SIZE", vec_size), MakeJitConstant("VLOAD", "CAT(vload, VEC_SIZE)"), MakeJitConstant("VSTORE", "CAT(vstore, VEC_SIZE)"), MakeJitConstant("INPUT_VEC_TYPE", "MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_SIZE)"), @@ -71,10 
+86,28 @@ RMSKernelBase::DispatchData RMSKernelBfyxOpt::SetDefault(const rms_params& param dispatchData.maxSlmSize = max_lws; if (!params.has_dynamic_tensors()) { - dispatchData.dataSize = input.X().v * input.Y().v * input.Z().v; - dispatchData.dataCount = input.Batch().v * input.Feature().v; - dispatchData.slmSize = dispatchData.dataSize / 8; - dispatchData.leftovers = dispatchData.dataSize % 8; + // data size to be processed within a LWG + switch (params.ov_input_rank) { + case 1: + dispatchData.dataSize = input.Batch().v; + dispatchData.dataCount = 1; + break; + case 2: + dispatchData.dataSize = input.Feature().v; + dispatchData.dataCount = input.Batch().v; + break; + case 3: + dispatchData.dataSize = input.Y().v; + dispatchData.dataCount = input.Batch().v * input.Feature().v; + break; + default: + dispatchData.dataSize = input.X().v; + dispatchData.dataCount = input.Batch().v * input.Feature().v * input.Z().v * input.Y().v; + break; + } + + dispatchData.slmSize = dispatchData.dataSize / vec_size; + dispatchData.leftovers = dispatchData.dataSize % vec_size; dispatchData.gws[0] = dispatchData.slmSize; dispatchData.gws[1] = dispatchData.dataCount; @@ -96,12 +129,12 @@ bool RMSKernelBfyxOpt::Validate(const Params& p, const optional_params& o) const if (!gamma.is_dynamic()) { size_t data_size = gamma.LogicalSize(); - if (data_size < 8) { + if (data_size < vec_size) { return false; } auto local_mem_per_wi = 2 * BytesPerElement(params.inputs[0].GetDType()); auto max_lws = std::min(params.engineInfo.maxWorkGroupSize, params.engineInfo.maxLocalMemSize / local_mem_per_wi); - auto slm_size = data_size / 8; + auto slm_size = data_size / vec_size; if (slm_size > max_lws) { return false; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.h index a9b49c4c1cc654..0484ad322c5dea 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.h +++ 
b/src/plugins/intel_gpu/src/kernel_selector/kernels/rms/rms_kernel_bfyx_opt.h @@ -21,5 +21,6 @@ class RMSKernelBfyxOpt : public RMSKernelBase { bool Validate(const Params&, const optional_params&) const override; DispatchData SetDefault(const rms_params& params) const override; JitConstants GetJitConstants(const rms_params& params, DispatchData dispatchData) const override; + const size_t vec_size = 8; }; } // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.cpp index bcd192454f3d3a..9c75396d432e30 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.cpp @@ -34,7 +34,7 @@ static std::function)> constant_value(const float targ }; } -RMSFusion::RMSFusion() { +RMSFusion::RMSFusion(uint64_t max_work_group_size) { using namespace ov::pass::pattern; // Detect RMS decomposition pattern @@ -82,6 +82,20 @@ RMSFusion::RMSFusion() { } const auto& gamma_node = pattern_map.at(gamma).get_node_shared_ptr(); + const auto& gamma_shape = gamma_node->get_output_partial_shape(0).to_shape(); + + const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr(); + const auto & axes = pattern_map.at(mean_axes).get_node_shared_ptr(); + auto axes_constant = std::dynamic_pointer_cast(axes); + auto axes_val = axes_constant->cast_vector(); + // allow last dimension only + if ((axes_val[0] != -1) && (axes_val[0] != (static_cast(mean_node->get_input_partial_shape(0).size()) - 1))) + return false; + + const int32_t vec_size = 8; + if (static_cast((gamma_shape.back() / vec_size)) > static_cast(max_work_group_size)) + return false; + auto output_type = m.get_match_root()->get_output_element_type(0); auto rms = std::make_shared(x_output, diff --git a/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.hpp b/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.hpp index 
66f236f3f26c38..8b8ee4867c5bc8 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.hpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.hpp @@ -12,7 +12,7 @@ namespace intel_gpu { class RMSFusion : public ov::pass::MatcherPass { public: OPENVINO_RTTI("RMSFusion", "0"); - RMSFusion(); + RMSFusion(uint64_t max_work_group_size); }; } // namespace intel_gpu diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 69caa4ff53f531..0c6f174dc69832 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -702,7 +702,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + manager.register_pass(device_info.max_work_group_size); manager.register_pass(); manager.register_pass(); if (!device_info.supports_immad) diff --git a/src/plugins/intel_gpu/tests/unit/transformations/rms_norm_decomposition_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/rms_norm_decomposition_test.cpp index 26d8638d2b904e..2266f21f8a8f99 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/rms_norm_decomposition_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/rms_norm_decomposition_test.cpp @@ -37,7 +37,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest1) { auto comp = std::make_shared(mul2, ov::element::f16); model = std::make_shared(ov::NodeVector{comp}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(32); } { auto input = std::make_shared(ov::element::f32, ov::Shape{1, 2, 6}); @@ -66,7 +66,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest2) { auto comp = std::make_shared(mul2, ov::element::f16); model = std::make_shared(ov::NodeVector{comp}, ov::ParameterVector{input}); - manager.register_pass(); + 
manager.register_pass(32); } } @@ -88,7 +88,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest3) { auto comp = std::make_shared(mul2, ov::element::f16); model = std::make_shared(ov::NodeVector{comp}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(32); } } @@ -110,7 +110,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest4) { auto comp = std::make_shared(mul2, ov::element::f16); model = std::make_shared(ov::NodeVector{comp}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(32); } } @@ -132,7 +132,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest5) { auto comp = std::make_shared(mul2, ov::element::f16); model = std::make_shared(ov::NodeVector{comp}, ov::ParameterVector{input}); - manager.register_pass(); + manager.register_pass(32); } { auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 6});