Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for both row and column MultiIndex in dialects #154

Merged
merged 8 commits into from
Apr 26, 2024
1 change: 1 addition & 0 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:
run: |
python -m pip install --upgrade pip pytest
pip install -r requirements.txt
python -m pip install --upgrade --force-reinstall git+https://github.com/LLNL/hatchet.git@develop
python setup.py install
python setup.py build_ext --inplace
python -m pip list
Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,23 @@ version = "2024.1.0"
description = "A Python-based toolkit for analyzing ensemble performance data."
license = "MIT"

[tool.ruff]
line-length = 88
target-version = 'py37'
include = ['\.pyi?$']
exclude = [
".eggs",
".git",
".hg",
".mypy_cache",
".tox",
".venv",
"_build",
"buck-out",
"build",
"dist",
]

[tool.black]
line-length = 88
target-version = ['py37']
Expand Down
72 changes: 72 additions & 0 deletions thicket/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2022 Lawrence Livermore National Security, LLC and other
# Thicket Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

# make flake8 unused names in this file.
# flake8: noqa: F401

from hatchet.query import (
# New style queries
# #################
#
# Core query types
Query,
ObjectQuery,
StringQuery,
parse_string_dialect,
# Compound queries
CompoundQuery,
ConjunctionQuery,
DisjunctionQuery,
ExclusiveDisjunctionQuery,
NegationQuery,
# Errors
InvalidQueryPath,
InvalidQueryFilter,
RedundantQueryFilterWarning,
BadNumberNaryQueryArgs,
#
# Old style queries
# #################
AbstractQuery,
NaryQuery,
AndQuery,
IntersectionQuery,
OrQuery,
UnionQuery,
XorQuery,
SymDifferenceQuery,
NotQuery,
QueryMatcher,
CypherQuery,
parse_cypher_query,
is_hatchet_query,
)

__all__ = [
"Query",
"ObjectQuery",
"StringQuery",
"parse_string_dialect",
"CompoundQuery",
"ConjunctionQuery",
"DisjunctionQuery",
"ExclusiveDisjunctionQuery",
"NegationQuery",
"InvalidQueryPath",
"InvalidQueryFilter",
"RedundantQueryFilterWarning",
"BadNumberNaryQueryArgs",
"is_hatchet_query",
]


def is_new_style_query(query_obj):
return issubclass(type(query_obj), Query) or issubclass(
type(query_obj), CompoundQuery
)


def is_old_style_query(query_obj):
return issubclass(type(query_obj), AbstractQuery)
76 changes: 76 additions & 0 deletions thicket/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,79 @@ def test_query(rajaperf_cuda_block128_1M_cali):
)

check_query(th, hnids, query)


def test_object_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):
th1 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[0])
th2 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[1])
th_cj = Thicket.concat_thickets([th1, th2], axis="columns")

query = [
("+", {(0, "Avg time/rank"): "> 10.0", (1, "Avg time/rank"): "> 10.0"}),
]

root = th_cj.graph.roots[0]
match = list(
set(
[
root, # RAJAPerf
root.children[1], # RAJAPerf.Apps
root.children[2], # RAJAPerf.Basic
root.children[3], # RAJAPerf.Lcals
root.children[3].children[0], # RAJAPerf.Lcals.Lcals_DIFF_PREDICT
root.children[4], # RAJAPerf.Polybench
]
)
)

new_th = th_cj.query(query, multi_index_mode="all")
queried_nodes = list(new_th.graph.traverse())

match_frames = list(sorted([n.frame for n in match]))
queried_frames = list(sorted([n.frame for n in queried_nodes]))

assert len(queried_nodes) == len(match)
assert all(m == q for m, q in zip(match_frames, queried_frames))
idx = pd.IndexSlice
assert (
(new_th.dataframe.loc[idx[queried_nodes, :], (0, "Avg time/rank")] > 10.0)
& (new_th.dataframe.loc[idx[queried_nodes, :], (1, "Avg time/rank")] > 10.0)
).all()


def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):
th1 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[0])
th2 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[1])
th_cj = Thicket.concat_thickets([th1, th2], axis="columns")

query = """MATCH ("+", p)
WHERE p.(0, "Avg time/rank") > 10.0 AND p.(1, "Avg time/rank") > 10.0
"""

root = th_cj.graph.roots[0]
match = list(
set(
[
root, # RAJAPerf
root.children[1], # RAJAPerf.Apps
root.children[2], # RAJAPerf.Basic
root.children[3], # RAJAPerf.Lcals
root.children[3].children[0], # RAJAPerf.Lcals.Lcals_DIFF_PREDICT
root.children[4], # RAJAPerf.Polybench
]
)
)

new_th = th_cj.query(query, multi_index_mode="all")
queried_nodes = list(new_th.graph.traverse())

match_frames = list(sorted([n.frame for n in match]))
queried_frames = list(sorted([n.frame for n in queried_nodes]))

assert len(queried_nodes) == len(match)
assert all(m == q for m, q in zip(match_frames, queried_frames))
idx = pd.IndexSlice
assert (
(new_th.dataframe.loc[idx[queried_nodes, :], (0, "Avg time/rank")] > 10.0)
& (new_th.dataframe.loc[idx[queried_nodes, :], (1, "Avg time/rank")] > 10.0)
).all()
45 changes: 31 additions & 14 deletions thicket/thicket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
import numpy as np
from hatchet import GraphFrame
from hatchet.graph import Graph
from hatchet.query import AbstractQuery, QueryMatcher
from hatchet.query import QueryEngine
from thicket.query import (
Query,
ObjectQuery,
parse_string_dialect,
is_hatchet_query,
is_old_style_query,
)
import tqdm

from thicket.ensemble import Ensemble
Expand Down Expand Up @@ -77,7 +84,7 @@ def __init__(
)
else:
self.statsframe = statsframe

self.query_engine = QueryEngine()
self.performance_cols = helpers._get_perf_columns(self.dataframe)

def __eq__(self, other):
Expand Down Expand Up @@ -984,7 +991,7 @@ def intersection(self):
"""

# Row that didn't exist will contain "None" in the name column.
query = QueryMatcher().match(
query = Query().match(
".", lambda row: row["name"].apply(lambda n: n is not None).all()
)
intersected_th = self.query(query)
Expand Down Expand Up @@ -1058,12 +1065,14 @@ def filter(self, filter_func):
"Invalid function: thicket.filter(), please use thicket.filter_metadata() or thicket.filter_stats()"
)

def query(self, query_obj, squash=True, update_inc_cols=True):
def query(
self, query_obj, squash=True, update_inc_cols=True, multi_index_mode="off"
):
"""Apply a Hatchet query to the Thicket object.

Arguments:
query_obj (AbstractQuery): the query, represented as by a subclass of
Hatchet's AbstractQuery
query_obj (AbstractQuery, Query, or CompoundQuery): the query, represented as by Query, CompoundQuery, or (for legacy support)
a subclass of Hatchet's AbstractQuery
squash (bool): if true, run Thicket.squash before returning the result of
the query
update_inc_cols (boolean, optional): if True, update inclusive columns when
Expand All @@ -1072,20 +1081,28 @@ def query(self, query_obj, squash=True, update_inc_cols=True):
Returns:
(thicket): a new Thicket object containing the data that matches the query
"""
if isinstance(query_obj, (list, str)):
raise UnsupportedQuery(
"Object and String queries from Hatchet are not yet supported in Thicket"
local_query_obj = query_obj
if isinstance(query_obj, list):
local_query_obj = ObjectQuery(query_obj, multi_index_mode=multi_index_mode)
elif isinstance(query_obj, str):
local_query_obj = parse_string_dialect(
query_obj, multi_index_mode=multi_index_mode
)
elif not issubclass(type(query_obj), AbstractQuery):
elif not is_hatchet_query(query_obj):
raise TypeError(
"Input to 'query' must be a Hatchet query (i.e., list, str, or subclass of AbstractQuery)"
"Encountered unrecognized query type (expected Query, CompoundQuery, or AbstractQuery, got {})".format(
type(query_obj)
)
)
dframe_copy = self.dataframe.copy()
index_names = self.dataframe.index.names
dframe_copy.reset_index(inplace=True)
query = query_obj
# TODO Add a conditional here to parse Object and String queries when supported
query_matches = query.apply(self)
query = (
local_query_obj
if not is_old_style_query(query_obj)
else local_query_obj._get_new_query()
)
query_matches = self.query_engine.apply(query, self.graph, self.dataframe)
filtered_df = dframe_copy.loc[dframe_copy["node"].isin(query_matches)]
if filtered_df.shape[0] == 0:
raise EmptyQuery("The provided query would have produced an empty Thicket.")
Expand Down
Loading