diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp
index ef5738be08..f1730de0e1 100644
--- a/include/ck/library/utility/host_tensor.hpp
+++ b/include/ck/library/utility/host_tensor.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
         else
             os << delim;
 
-        if constexpr(std::is_same_v || std::is_same_v)
+        using RangeType = ck::remove_cvref_t;
+        if constexpr(std::is_same_v || std::is_same_v ||
+                     std::is_same_v)
         {
             os << ck::type_convert(v);
         }
+        else if constexpr(std::is_same_v)
+        {
+            const auto packed_floats = ck::type_convert(v);
+            const ck::vector_type vector_of_floats{packed_floats};
+            os << vector_of_floats.template AsType()[ck::Number<0>{}] << delim
+               << vector_of_floats.template AsType()[ck::Number<1>{}];
+        }
         else
         {
             os << static_cast(v);
diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp
index f372756e68..9120ce62ca 100644
--- a/include/ck/utility/type_convert.hpp
+++ b/include/ck/utility/type_convert.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert(f8x2_ocp_
 #endif
 }
 
+template <>
+inline __host__ __device__ float2_t type_convert(pk_i4_t x)
+{
+    uint8_t x_u8 = ck::bit_cast(x);
+    uint8_t x_l = (x_u8 & 0x0f) >> 0;
+    uint8_t x_h = (x_u8 & 0xf0) >> 4;
+
+    auto l_f32 = ck::type_convert(x_l);
+    auto h_f32 = ck::type_convert(x_h);
+
+    return {l_f32, h_f32};
+}
+
 template <>
 inline __host__ __device__ half2_t type_convert(float2_t x)
 {
diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp
index ed7e86ded8..2054ffbbb3 100644
--- a/profiler/include/profiler/profile_gemm_universal_impl.hpp
+++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -177,7 +177,7 @@ bool profile_gemm_universal_impl(int do_verification,
         }
     }
 
-    if(is_same_v && is_same_v)
+    if constexpr(is_same_v && is_same_v)
     {
         // vector pk_i4x4 permute
         for(int i = 0; i < N; i++)
@@ -188,7 +188,7 @@ bool profile_gemm_universal_impl(int do_verification,
 
                 for(int k = 0; k < 4; k++)
                 {
-                    int i4x2 = b_k_n_permute(j + k * 2, i);
+                    int i4x2 = b_k_n_permute(j + k * 2, i).data;
                     input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
                     input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
                 }