Skip to content

Commit

Permalink
Bugfix _sync_nodes_frame (#99)
Browse files Browse the repository at this point in the history
* Sort graph and dataframe to avoid bug when there are multiple nodes with the same frame

* Add new validation function to check for invalid dataframes

* Remove limitation of only two index levels

* Change incorrect variable name
  • Loading branch information
michaelmckinsey1 authored Oct 31, 2023
1 parent 3de843d commit 83428d1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
13 changes: 8 additions & 5 deletions thicket/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd

import thicket.helpers as helpers
from .utils import verify_sorted_profile, verify_thicket_structures
from .utils import validate_dataframe, verify_sorted_profile, verify_thicket_structures


class Ensemble:
Expand Down Expand Up @@ -46,10 +46,10 @@ def _unify(thickets, inplace=False):
for i in range(len(_thickets)):
# Set all graphs to the union graph
_thickets[i].graph = union_graph
# Necessary to change dataframe hatchet id's to match the nodes in the graph
helpers._sync_nodes_frame(union_graph, _thickets[i].dataframe)
# For tree diff. dataframes need to be sorted.
_thickets[i].dataframe.sort_index(inplace=True)
# Necessary to change dataframe hatchet id's to match the nodes in the graph
helpers._sync_nodes_frame(union_graph, _thickets[i].dataframe)
return union_graph, _thickets

@staticmethod
Expand Down Expand Up @@ -285,6 +285,9 @@ def _handle_statsframe():
# Step 2D: Handle other Thicket objects.
_handle_misc()

# Validate dataframe
validate_dataframe(combined_th.dataframe)

return combined_th

@staticmethod
Expand Down Expand Up @@ -385,8 +388,8 @@ def _agg_to_set(obj):
unify_inc_metrics = list(set(unify_inc_metrics))
unify_exc_metrics = list(set(unify_exc_metrics))

# Workaround for graph/df node id mismatch.
helpers._sync_nodes(unify_graph, unify_df)
# Validate unify_df
validate_dataframe(unify_df)

unify_parts = (
unify_graph,
Expand Down
21 changes: 12 additions & 9 deletions thicket/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# SPDX-License-Identifier: MIT

import copy

import pandas as pd


Expand Down Expand Up @@ -152,17 +154,19 @@ def _sync_nodes_frame(gh, df):
TODO: This function may be superior to _sync_nodes and may be able to replace it.
Need to investigate.
"""
assert df.index.nlevels == 2 # For num_profiles assumption

# TODO: Graph function to list conversion: move to Hatchet?
gh_node_list = []
for gh_node in gh.traverse():
gh_node_list.append(gh_node)
# Sort the graph node list
gh_node_list.sort(key=lambda node: hash(node))

num_profiles = len(df.groupby(level=1))
index_names = df.index.names
df.reset_index(inplace=True)
df_node_list = df["node"][::num_profiles].to_list()
df_node_list = list(set(df.index.get_level_values("node")))
df_node_list_cp = copy.deepcopy(df_node_list)
# Check list sorted
assert sorted(df_node_list, key=lambda node: hash(node))

# Sequentially walk through graph and dataframe and modify dataframe hnid's based off graph equivalent
i = 0
Expand All @@ -176,12 +180,11 @@ def _sync_nodes_frame(gh, df):

# Extend list to match multi-index dataframe structure
df_list_full = []
for node in df_node_list:
temp = []
for idx in range(num_profiles):
temp.append(node)
df_list_full.extend(temp)
for i, node in enumerate(df_node_list):
num_profiles = len(df.loc[df_node_list_cp[i]])
df_list_full.extend([node] * num_profiles)
# Update nodes in the dataframe
df.reset_index(inplace=True)
df["node"] = df_list_full

df.set_index(index_names, inplace=True)
Expand Down
42 changes: 42 additions & 0 deletions thicket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,48 @@
from collections import OrderedDict


def validate_dataframe(df):
    """Check validity of a Thicket DataFrame.

    Three independent checks are run, in this order:

    1. No node has duplicate values in the innermost index level.
    2. The hashes of the unique nodes form the sequence 0, 1, ..., n-1.
    3. Every value in a node's "name" column is either ``None`` or equal to
       ``node.frame["name"]``.

    Arguments:
        df (DataFrame): dataframe with a MultiIndex whose outermost level is named
            "node" and with a "name" column. Index values of the "node" level must
            be hashable objects exposing a ``frame`` mapping with a "name" key.

    Raises:
        IndexError: a node has duplicate values in the innermost index level.
        ValueError: a node id is missing, or a "name" value does not match its node.
    """
    # Unique nodes in the index; computed once and shared by all three checks.
    nodes = set(df.index.get_level_values("node"))

    def _check_duplicate_inner_idx(df):
        """Check for duplicate values in the innermost index."""
        # df.loc[node] drops the "node" level, so the innermost level of what
        # remains sits at position nlevels - 2.
        inner_level = df.index.nlevels - 2
        for node in nodes:
            inner_idx_values = sorted(
                df.loc[node].index.get_level_values(inner_level).tolist()
            )
            # A set collapses repeated values, so a shorter set means duplicates.
            if len(set(inner_idx_values)) != len(inner_idx_values):
                raise IndexError(
                    f"The Thicket.dataframe's index has duplicate values. {inner_idx_values}"
                )

    def _check_missing_hnid(df):
        """Check if there are missing hatchet nid's."""
        # Sets are unordered, so walk the nodes in ascending hash order explicitly
        # instead of relying on whatever order set iteration happens to produce.
        for i, node in enumerate(sorted(nodes, key=hash)):
            if hash(node) != i:
                raise ValueError(
                    f"The Thicket.dataframe's index is either not sorted or has a missing node. {hash(node)} ({node}) != {i}"
                )

    def _validate_name_column(df):
        """Check if all of the values in a node's name column are either its name or None."""
        for node in nodes:
            names = set(df.loc[node]["name"].tolist())
            node_name = node.frame["name"]
            for name in names:
                # None is tolerated; any other value must equal the node's own name.
                if name is not None and name != node_name:
                    raise ValueError(
                        f"Value in the Thicket.dataframe's 'name' column is not valid. {name} != {node_name}"
                    )

    _check_duplicate_inner_idx(df)
    _check_missing_hnid(df)
    _validate_name_column(df)


def verify_sorted_profile(thicket_component):
"""Assertion to check if profiles are sorted in a thicket dataframe
Expand Down

0 comments on commit 83428d1

Please sign in to comment.