Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix erase bug #659

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions include/cuco/detail/equal_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace detail {
* @brief Enum of equality comparison results
*/
// ENUM VALUE MATTERS, DO NOT CHANGE
enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, AVAILABLE = 3 };
// Outcome of comparing a probe key against the key stored in a slot.
// The explicit numeric values below are kept exactly as they were.
enum class equal_result : int32_t {
  UNEQUAL = 0,  // result of the key comparison (presumably: keys differ)
  EQUAL   = 1,  // result of the key comparison (presumably: keys match)
  EMPTY   = 2,  // slot key bitwise-matches the empty sentinel
  ERASED  = 3   // slot key bitwise-matches the erased sentinel
};

// Compile-time switch used as the `IsInsert` template argument of the
// comparison operator to pick which sentinel checks it performs.
enum class is_insert : bool {
  YES,  // insert path: empty and erased sentinels are checked before comparing keys
  NO    // non-insert path: only the empty sentinel is checked before comparing keys
};

Expand Down Expand Up @@ -97,10 +97,13 @@ struct equal_wrapper {
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
{
if constexpr (IsInsert == is_insert::YES) {
return (cuco::detail::bitwise_compare(rhs, empty_sentinel_) or
cuco::detail::bitwise_compare(rhs, erased_sentinel_))
? equal_result::AVAILABLE
: this->equal_to(lhs, rhs);
if (cuco::detail::bitwise_compare(rhs, empty_sentinel_)) {
return equal_result::EMPTY;
} else if (cuco::detail::bitwise_compare(rhs, erased_sentinel_)) {
return equal_result::ERASED;
} else {
return this->equal_to(lhs, rhs);
}
} else {
return cuco::detail::bitwise_compare(rhs, empty_sentinel_) ? equal_result::EMPTY
: this->equal_to(lhs, rhs);
Expand Down
56 changes: 39 additions & 17 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,12 @@ class open_addressing_ref_impl {
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

[[maybe_unused]] auto probing_iter_copy = probing_iter;
[[maybe_unused]] bool erased = false;
[[maybe_unused]] bool empty_after_erased = false;

while (true) {
[[maybe_unused]] continue_after_erased:
auto const bucket_slots = storage_ref_[*probing_iter];

for (auto& slot_content : bucket_slots) {
Expand All @@ -393,21 +398,34 @@ class open_addressing_ref_impl {
if constexpr (not allows_duplicates) {
// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return false; }
if (eq_res == detail::equal_result::ERASED and not erased and not empty_after_erased) {
erased = true;
probing_iter_copy = probing_iter;
}
if (eq_res == detail::equal_result::EMPTY and erased and not empty_after_erased) {
empty_after_erased = true;
probing_iter = probing_iter_copy;
goto continue_after_erased;
}
}
if (eq_res == detail::equal_result::AVAILABLE) {
auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
slot_content,
val)) {
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;

if (not erased or empty_after_erased) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
switch (
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
slot_content,
val)) {
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;
}
}
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
}
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
}
}
}
Expand Down Expand Up @@ -442,8 +460,10 @@ class open_addressing_ref_impl {
for (auto i = 0; i < bucket_size; ++i) {
switch (
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(bucket_slots[i]))) {
case detail::equal_result::AVAILABLE:
return bucket_probing_results{detail::equal_result::AVAILABLE, i};
case detail::equal_result::EMPTY:
return bucket_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::ERASED:
return bucket_probing_results{detail::equal_result::ERASED, i};
case detail::equal_result::EQUAL: {
if constexpr (allows_duplicates) {
continue;
Expand All @@ -463,7 +483,8 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EQUAL)) { return false; }
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status =
Expand Down Expand Up @@ -538,7 +559,7 @@ class open_addressing_ref_impl {
}
return {iterator{&bucket_ptr[i]}, false};
}
if (eq_res == detail::equal_result::AVAILABLE) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
switch (this->attempt_insert_stable(bucket_ptr + i, bucket_slots[i], val)) {
case insert_result::SUCCESS: {
if constexpr (has_payload) {
Expand Down Expand Up @@ -626,7 +647,8 @@ class open_addressing_ref_impl {
return {iterator{reinterpret_cast<value_type*>(res)}, false};
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);
Expand Down
10 changes: 6 additions & 4 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ class operator_impl<
payload_ref.store(val.second, cuda::memory_order_relaxed);
return;
}
if (eq_res == detail::equal_result::AVAILABLE) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
if (attempt_insert_or_assign(slot_ptr, val)) { return; }
}
}
Expand Down Expand Up @@ -571,7 +571,8 @@ class operator_impl<
return;
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status =
Expand Down Expand Up @@ -883,7 +884,7 @@ class operator_impl<
op(cuda::atomic_ref<T, Scope>{slot_ptr->second}, val.second);
return false;
}
if (eq_res == detail::equal_result::AVAILABLE) {
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
switch (ref_.attempt_insert_or_apply<UseDirectApply>(slot_ptr, slot_content, val, op)) {
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: {
Expand Down Expand Up @@ -970,7 +971,8 @@ class operator_impl<
return false;
}

auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
(state == detail::equal_result::ERASED));
if (group_contains_available) {
auto const src_lane = __ffs(group_contains_available) - 1;
auto const status = [&, target_idx = intra_bucket_index]() {
Expand Down
4 changes: 4 additions & 0 deletions tests/static_map/erase_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ void test_erase(Map& map, size_type num_keys)
REQUIRE(cuco::test::all_of(
d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{}));

// Regression test for #606: re-insert the second half of the keys, which are still present in the map
map.insert(pairs_begin + num_keys / 2, pairs_begin + num_keys);
// TODO: extend this check to insert_and_find, insert_or_assign, and insert_or_apply

map.erase(keys_begin + num_keys / 2, keys_begin + num_keys);
REQUIRE(map.size() == 0);
}
Expand Down
Loading