Skip to content

Commit

Permalink
collate memory accesses to avoid data races
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Jan 10, 2025
1 parent 8f9ee5e commit 56e16e6
Showing 1 changed file with 78 additions and 45 deletions.
123 changes: 78 additions & 45 deletions common/unified/components/range_minimum_query_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,45 @@ void compute_lookup_small(std::shared_ptr<const DefaultExecutor> exec,
static_assert(device_lut_type::type::num_trees <=
std::numeric_limits<tree_index_type>::max(),
"block type storage too small");
constexpr auto collation_width =
1 << (std::decay_t<decltype(block_argmin)>::bits_per_word_log2 -
ceil_log2_constexpr(ceil_log2_constexpr(small_block_size)));
const device_lut_type lut{exec};
constexpr auto infinity = std::numeric_limits<IndexType>::max();
run_kernel(
exec,
[] GKO_KERNEL(auto block_idx, auto values, auto block_argmin,
[] GKO_KERNEL(auto collated_block_idx, auto values, auto block_argmin,
auto block_min, auto block_type, auto lut, auto size) {
const auto i = block_idx * small_block_size;
IndexType local_values[small_block_size];
int argmin = 0;
const auto num_blocks = ceildiv(size, small_block_size);
for (auto block_idx = collated_block_idx * collation_width;
block_idx <
std::min<int64>((collated_block_idx + 1) * collation_width,
num_blocks);
block_idx++) {
const auto i = block_idx * small_block_size;
IndexType local_values[small_block_size];
int argmin = 0;
#pragma unroll
for (int local_i = 0; local_i < small_block_size; local_i++) {
// use "infinity" as sentinel for minimum computations
local_values[local_i] =
local_i + i < size ? values[local_i + i] : infinity;
if (local_values[local_i] < local_values[argmin]) {
argmin = local_i;
for (int local_i = 0; local_i < small_block_size; local_i++) {
// use "infinity" as sentinel for minimum computations
local_values[local_i] =
local_i + i < size ? values[local_i + i] : infinity;
if (local_values[local_i] < local_values[argmin]) {
argmin = local_i;
}
}
const auto tree_number = lut->compute_tree_index(local_values);
const auto min = local_values[argmin];
// these writes are collated by the enclosing loop so a single thread
// handles the argmins for an entire memory word
block_argmin.set(block_idx, argmin);
block_min[block_idx] = min;
block_type[block_idx] =
static_cast<tree_index_type>(tree_number);
}
const auto tree_number = lut->compute_tree_index(local_values);
const auto min = local_values[argmin];
// TODO collate these so a single thread handles the argmins for an
// entire memory word
block_argmin.set(block_idx, argmin);
block_min[block_idx] = min;
block_type[block_idx] = static_cast<tree_index_type>(tree_number);
},
ceildiv(size, small_block_size), values, block_argmin, block_min,
block_type, lut.get(), size);
ceildiv(ceildiv(size, small_block_size), collation_width), values,
block_argmin, block_min, block_type, lut.get(), size);
}

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
Expand All @@ -72,46 +83,68 @@ void compute_lookup_large(
range_minimum_query_superblocks<IndexType>& superblocks)
{
using superblock_type = range_minimum_query_superblocks<IndexType>;
using word_type = typename superblock_type::storage_type;
// we need to collate all writes that target the same memory word in a
// single thread
constexpr auto level0_collation_width = sizeof(word_type) * CHAR_BIT;
constexpr auto infinity = std::numeric_limits<IndexType>::max();
// initialize the first level of blocks
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto block_min, auto superblocks,
[] GKO_KERNEL(auto collated_i, auto block_min, auto superblocks,
auto num_blocks) {
const auto min1 = block_min[i];
const auto min2 = i + 1 < num_blocks ? block_min[i + 1] : infinity;
// we need to use <= here to make sure ties always break to the left
superblocks.set_block_argmin(0, i, min1 <= min2 ? 0 : 1);
// the loop below collates these writes so a single thread handles the
// argmins for an entire memory word
for (auto i = collated_i * level0_collation_width;
i < std::min<int64>((collated_i + 1) * level0_collation_width,
num_blocks);
i++) {
const auto min1 = block_min[i];
const auto min2 =
i + 1 < num_blocks ? block_min[i + 1] : infinity;
// we need to use <= here to make sure ties always break to the
// left
superblocks.set_block_argmin(0, i, min1 <= min2 ? 0 : 1);
}
},
num_blocks, block_min, superblocks, num_blocks);
ceildiv(num_blocks, level0_collation_width), block_min, superblocks,
num_blocks);
// we computed argmins for blocks of size 2, now recursively combine them.
const auto num_levels = superblocks.num_levels();
for (int block_level = 1; block_level < num_levels; block_level++) {
const auto block_size =
superblock_type::block_size_for_level(block_level);
// we need block_level + 1 bits to represent values of size block_size
// and round up to the next power of two
const auto collation_width =
level0_collation_width / round_up_pow2(block_level + 1);
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto block_level, auto block_min,
auto superblocks, auto num_blocks) {
[] GKO_KERNEL(auto collated_i, auto block_level, auto block_min,
auto superblocks, auto num_blocks,
auto collation_width) {
const auto block_size =
superblock_type::block_size_for_level(block_level);
const auto i2 = i + block_size / 2;
const auto argmin1 =
i + superblocks.block_argmin(block_level - 1, i);
const auto argmin2 =
i2 < num_blocks
? i2 + superblocks.block_argmin(block_level - 1, i2)
: argmin1;
const auto min1 = block_min[argmin1];
const auto min2 = block_min[argmin2];
// we need to use <= here to make sure
// ties always break to the left
superblocks.set_block_argmin(
block_level, i, min1 <= min2 ? argmin1 - i : argmin2 - i);
// the loop below collates these writes so a single thread handles the
// argmins for an entire memory word
for (auto i = collated_i * collation_width;
i < std::min<int64>((collated_i + 1) * collation_width,
num_blocks);
i++) {
const auto i2 = i + block_size / 2;
const auto argmin1 =
i + superblocks.block_argmin(block_level - 1, i);
const auto argmin2 =
i2 < num_blocks
? i2 + superblocks.block_argmin(block_level - 1, i2)
: argmin1;
const auto min1 = block_min[argmin1];
const auto min2 = block_min[argmin2];
// we need to use <= here to make sure
// ties always break to the left
superblocks.set_block_argmin(
block_level, i,
min1 <= min2 ? argmin1 - i : argmin2 - i);
}
},
num_blocks, block_level, block_min, superblocks, num_blocks);
ceildiv(num_blocks, collation_width), block_level, block_min,
superblocks, num_blocks, collation_width);
}
}

Expand Down

0 comments on commit 56e16e6

Please sign in to comment.