Skip to content

Commit

Permalink
Fix universal gemm profiler for pk_i4_t (#1790)
Browse files Browse the repository at this point in the history
* Fix universal gemm profiler for pk_i4_t

* fix
  • Loading branch information
bartekxk authored Jan 4, 2025
1 parent 37b3514 commit 888317e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
13 changes: 11 additions & 2 deletions include/ck/library/utility/host_tensor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
else
os << delim;

if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
using RangeType = ck::remove_cvref_t<decltype(v)>;
if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
std::is_same_v<RangeType, ck::bhalf_t>)
{
os << ck::type_convert<float>(v);
}
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
<< vector_of_floats.template AsType<float>()[ck::Number<1>{}];
}
else
{
os << static_cast<T>(v);
Expand Down
15 changes: 14 additions & 1 deletion include/ck/utility/type_convert.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#endif
}

template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;

auto l_f32 = ck::type_convert<float>(x_l);
auto h_f32 = ck::type_convert<float>(x_h);

return {l_f32, h_f32};
}

template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
Expand Down
6 changes: 3 additions & 3 deletions profiler/include/profiler/profile_gemm_universal_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -177,7 +177,7 @@ bool profile_gemm_universal_impl(int do_verification,
}
}

if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
if constexpr(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
{
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
Expand All @@ -188,7 +188,7 @@ bool profile_gemm_universal_impl(int do_verification,

for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
int i4x2 = b_k_n_permute(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
Expand Down

0 comments on commit 888317e

Please sign in to comment.