diff --git a/core/test/base/extended_float.cpp b/core/test/base/extended_float.cpp index 6901ae72152..6098a70b728 100644 --- a/core/test/base/extended_float.cpp +++ b/core/test/base/extended_float.cpp @@ -204,18 +204,21 @@ TEST_F(FloatToHalf, TruncatesSmallNumber) } -TEST_F(FloatToHalf, TruncatesLargeNumber) +TEST_F(FloatToHalf, TruncatesLargeNumberRoundToEven) { - half x = create_from_bits("1" "10001110" "10010011111000010000100"); - - #if defined(SYCL_LANGUAGE_VERSION) && \ - (__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7)) - // TODO: sycl::half seems to did rounding, but ours just truncates - ASSERT_EQ(get_bits(x), get_bits("1" "11110" "1001010000")); - #else - ASSERT_EQ(get_bits(x), get_bits("1" "11110" "1001001111")); - #endif - + half neg_x = create_from_bits("1" "10001110" "10010011111000010000100"); + half neg_x2 = create_from_bits("1" "10001110" "10010011101000010000100"); + half x = create_from_bits("0" "10001110" "10010011111000010000100"); + half x2 = create_from_bits("0" "10001110" "10010011101000010000100"); + half x3 = create_from_bits("0" "10001110" "10010011101000000000000"); + half x4 = create_from_bits("0" "10001110" "10010011111000000000000"); + + EXPECT_EQ(get_bits(x), get_bits("0" "11110" "1001010000")); + EXPECT_EQ(get_bits(x2), get_bits("0" "11110" "1001001111")); + EXPECT_EQ(get_bits(x3), get_bits("0" "11110" "1001001110")); + EXPECT_EQ(get_bits(x4), get_bits("0" "11110" "1001010000")); + EXPECT_EQ(get_bits(neg_x), get_bits("1" "11110" "1001010000")); + EXPECT_EQ(get_bits(neg_x2), get_bits("1" "11110" "1001001111")); } diff --git a/include/ginkgo/core/base/half.hpp b/include/ginkgo/core/base/half.hpp index 1a8c1e1dfd1..446d085754d 100644 --- a/include/ginkgo/core/base/half.hpp +++ b/include/ginkgo/core/base/half.hpp @@ -462,8 +462,21 @@ class half { // TODO: handle denormals return conv::shift_sign(data_); } else { - return conv::shift_sign(data_) | exp | - conv::shift_significand(data_); + // Rounding to even + const auto result = conv::shift_sign(data_) | exp | + conv::shift_significand(data_); + // return result + ((result & 1) && + // ((data_ >> (f32_traits::significand_bits - + // f16_traits::significand_bits - 1)) & + // 1)); + const auto tail = + data_ & static_cast( + (1 << conv::significand_offset) - 1); + + constexpr auto half = static_cast( + 1 << (conv::significand_offset - 1)); + return result + + (tail > half || ((tail == half) && (result & 1))); } } } diff --git a/test/components/fill_array_kernels.cpp b/test/components/fill_array_kernels.cpp index 8ee0089c49c..bb7e195ad2c 100644 --- a/test/components/fill_array_kernels.cpp +++ b/test/components/fill_array_kernels.cpp @@ -53,7 +53,7 @@ class FillArray : public CommonTestFixture { protected: using value_type = T; FillArray() - : total_size(63531), + : total_size(3000), vals{ref, total_size}, dvals{exec, total_size}, seqs{ref, total_size} @@ -68,8 +68,8 @@ class FillArray : public CommonTestFixture { gko::array seqs; }; -TYPED_TEST_SUITE(FillArray, gko::test::ValueAndIndexTypes, - TypenameNameGenerator); +using LIST = ::testing::Types; +TYPED_TEST_SUITE(FillArray, LIST, TypenameNameGenerator); TYPED_TEST(FillArray, EqualsReference) @@ -88,5 +88,10 @@ TYPED_TEST(FillArray, FillSeqEqualsReference) gko::kernels::EXEC_NAMESPACE::components::fill_seq_array( this->exec, this->dvals.get_data(), this->total_size); + this->dvals.set_executor(this->ref); + for (gko::size_type i = 2000; i < this->total_size; i++) { + std::cout << i << " " << this->seqs.get_data()[i] << " device " + << this->dvals.get_data()[i] << std::endl; + } GKO_ASSERT_ARRAY_EQ(this->seqs, this->dvals); }