Skip to content

Commit

Permalink
fix: hash object should properly handle null values as 'other'
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Oct 7, 2024
1 parent 9882acd commit cdd9b96
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 94 deletions.
109 changes: 57 additions & 52 deletions packages/vaex-core/src/hash_primitives.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,24 +303,26 @@ class hash_base : public hash_common<Derived, T, Hashmap<T, int64_t>> {
py::object key_array() {
py::array_t<key_type> output_array(this->length());
auto output = output_array.template mutable_unchecked<1>();
py::gil_scoped_release gil;
auto offsets = this->offsets();
size_t map_index = 0;
int64_t natural_order = 0;
// TODO: can be parallel due to non-overlapping maps
for (auto &map : this->maps) {
for (auto &el : map) {
key_type key = el.first;
int64_t index = static_cast<Derived &>(*this).key_offset(natural_order++, map_index, el, offsets[map_index]);
output(index) = key;
{
py::gil_scoped_release gil;
auto offsets = this->offsets();
size_t map_index = 0;
int64_t natural_order = 0;
// TODO: can be parallel due to non-overlapping maps
for (auto &map : this->maps) {
for (auto &el : map) {
key_type key = el.first;
int64_t index = static_cast<Derived &>(*this).key_offset(natural_order++, map_index, el, offsets[map_index]);
output(index) = key;
}
map_index += 1;
}
if (this->nan_count) {
output(this->nan_index()) = NaNish<key_type>::value;
}
if (this->null_count) {
output(this->null_index()) = -1;
}
map_index += 1;
}
if (this->nan_count) {
output(this->nan_index()) = NaNish<key_type>::value;
}
if (this->null_count) {
output(this->null_index()) = -1;
}
return output_array;
}
Expand Down Expand Up @@ -630,46 +632,49 @@ class ordered_set : public hash_base<ordered_set<T2, Hashmap2>, T2, Hashmap2> {
if (result.strides()[0] != result.itemsize()) {
throw std::runtime_error("stride not equal to bytesize for output");
}
py::gil_scoped_release gil;
{
py::gil_scoped_release gil;

size_t nmaps = this->maps.size();
auto offsets = this->offsets();
if (nmaps == 1) {
auto &map0 = this->maps[0];
for (int64_t i = 0; i < size; i++) {
const key_type &value = input[i];
// the caller is responsible for finding masked values
if (custom_isnan(value)) {
output[i] = this->nan_value;
// TODO: the test fail here because we pass in NaN for None?
// but of course only in debug mode
assert(this->nan_count > 0);
} else {
auto search = map0.find(value);
if (search == map0.end()) {
output[i] = -1;
size_t nmaps = this->maps.size();
auto offsets = this->offsets();
if (nmaps == 1) {
auto &map0 = this->maps[0];
for (int64_t i = 0; i < size; i++) {
const key_type &value = input[i];
// the caller is responsible for finding masked values
if (custom_isnan(value)) {
if(this->null_count > 0) {
output[i] = this->nan_value;
} else {
output[i] = -1;
}
} else {
output[i] = search->second;
auto search = map0.find(value);
if (search == map0.end()) {
output[i] = -1;
} else {
output[i] = search->second;
}
}
}
}
} else {
for (int64_t i = 0; i < size; i++) {
const key_type &value = input[i];
// the caller is responsible for finding masked values
if (custom_isnan(value)) {
output[i] = this->nan_value;
// TODO: the test fail here because we pass in NaN for None?
// but of course only in debug mode
assert(this->nan_count > 0);
} else {
std::size_t hash = hasher_map_choice()(value);
size_t map_index = (hash % nmaps);
auto search = this->maps[map_index].find(value);
if (search == this->maps[map_index].end()) {
output[i] = -1;
} else {
for (int64_t i = 0; i < size; i++) {
const key_type &value = input[i];
// the caller is responsible for finding masked values
if (custom_isnan(value)) {
output[i] = this->nan_value;
// TODO: the test fail here because we pass in NaN for None?
// but of course only in debug mode
assert(this->nan_count > 0);
} else {
output[i] = search->second + offsets[map_index];
std::size_t hash = hasher_map_choice()(value);
size_t map_index = (hash % nmaps);
auto search = this->maps[map_index].find(value);
if (search == this->maps[map_index].end()) {
output[i] = -1;
} else {
output[i] = search->second + offsets[map_index];
}
}
}
}
Expand Down
89 changes: 47 additions & 42 deletions packages/vaex-core/src/hash_string.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,19 +555,35 @@ class ordered_set : public hash_base<ordered_set<T>, T, T, V> {
return result;
}
auto output = result.template mutable_unchecked<1>();
py::gil_scoped_release gil;
size_t nmaps = this->maps.size();
auto offsets = this->offsets();

if (nmaps == 1) {
auto &map0 = this->maps[0];
// split slow and fast path
if (strings->has_null()) {
for (int64_t i = 0; i < size; i++) {
if (strings->is_null(i)) {
output(i) = this->null_value;
assert(this->null_count > 0);
} else {
{
py::gil_scoped_release gil;
if (nmaps == 1) {
auto &map0 = this->maps[0];
// split slow and fast path
if (strings->has_null()) {
for (int64_t i = 0; i < size; i++) {
if (strings->is_null(i)) {
if(this->null_count > 0) {
output(i) = this->null_value;
} else {
output(i) = -1;
}
} else {
const string_view &key = strings->view(i);
auto search = map0.find(key);
auto end = map0.end();
if (search == end) {
output(i) = -1;
} else {
output(i) = search->second;
}
}
}
} else {
for (int64_t i = 0; i < size; i++) {
const string_view &key = strings->view(i);
auto search = map0.find(key);
auto end = map0.end();
Expand All @@ -579,27 +595,29 @@ class ordered_set : public hash_base<ordered_set<T>, T, T, V> {
}
}
} else {
for (int64_t i = 0; i < size; i++) {
const string_view &key = strings->view(i);
auto search = map0.find(key);
auto end = map0.end();
if (search == end) {
output(i) = -1;
} else {
output(i) = search->second;
// split slow and fast path
if (strings->has_null()) {
for (int64_t i = 0; i < size; i++) {
if (strings->is_null(i)) {
output(i) = this->null_value;
assert(this->null_count > 0);
} else {
const string_view &key = strings->view(i);
size_t hash = hasher_map_choice()(key);
size_t map_index = (hash % nmaps);
auto search = this->maps[map_index].find(key, hash);
auto end = this->maps[map_index].end();
if (search == end) {
output(i) = -1;
} else {
output(i) = search->second + offsets[map_index];
}
}
}
}
}
} else {
// split slow and fast path
if (strings->has_null()) {
for (int64_t i = 0; i < size; i++) {
if (strings->is_null(i)) {
output(i) = this->null_value;
assert(this->null_count > 0);
} else {
} else {
for (int64_t i = 0; i < size; i++) {
const string_view &key = strings->view(i);
size_t hash = hasher_map_choice()(key);
std::size_t hash = hasher_map_choice()(key);
size_t map_index = (hash % nmaps);
auto search = this->maps[map_index].find(key, hash);
auto end = this->maps[map_index].end();
Expand All @@ -610,19 +628,6 @@ class ordered_set : public hash_base<ordered_set<T>, T, T, V> {
}
}
}
} else {
for (int64_t i = 0; i < size; i++) {
const string_view &key = strings->view(i);
std::size_t hash = hasher_map_choice()(key);
size_t map_index = (hash % nmaps);
auto search = this->maps[map_index].find(key, hash);
auto end = this->maps[map_index].end();
if (search == end) {
output(i) = -1;
} else {
output(i) = search->second + offsets[map_index];
}
}
}
}
return result;
Expand Down

0 comments on commit cdd9b96

Please sign in to comment.