Skip to content

Commit

Permalink
Add support for both row and column MultiIndex in dialects (#154)
Browse files Browse the repository at this point in the history
* Adds support for 'new-style' queries in Thicket

* Adds unit test for MultiIndex column support in QL dialects

* Allows "raw" string and object dialect queries to be used in Thicket.query

* Adds license header and flake8 control comment to query.py

* Restores the old "example_cali" pytest fixture as "simple_cali"

* Updates unit tests to use the tip of develop for Hatchet

* Changes new query tests to use rajaperf_seq_O3_1M_cali fixture

* Adds ruff support in pyproject.toml
  • Loading branch information
ilumsden authored Apr 26, 2024
1 parent bf7cbf4 commit 68455ca
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:
run: |
python -m pip install --upgrade pip pytest
pip install -r requirements.txt
python -m pip install --upgrade --force-reinstall git+https://github.com/LLNL/hatchet.git@develop
python setup.py install
python setup.py build_ext --inplace
python -m pip list
Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,23 @@ version = "2024.1.0"
description = "A Python-based toolkit for analyzing ensemble performance data."
license = "MIT"

[tool.ruff]
line-length = 88
target-version = 'py37'
# Ruff's `include` takes glob patterns, not regular expressions (unlike
# Black's `include`), so Python sources and stubs are selected with globs.
include = ["*.py", "*.pyi"]
exclude = [
".eggs",
".git",
".hg",
".mypy_cache",
".tox",
".venv",
"_build",
"buck-out",
"build",
"dist",
]

[tool.black]
line-length = 88
target-version = ['py37']
Expand Down
72 changes: 72 additions & 0 deletions thicket/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2022 Lawrence Livermore National Security, LLC and other
# Thicket Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

# make flake8 ignore unused names in this file.
# flake8: noqa: F401

from hatchet.query import (
# New style queries
# #################
#
# Core query types
Query,
ObjectQuery,
StringQuery,
parse_string_dialect,
# Compound queries
CompoundQuery,
ConjunctionQuery,
DisjunctionQuery,
ExclusiveDisjunctionQuery,
NegationQuery,
# Errors
InvalidQueryPath,
InvalidQueryFilter,
RedundantQueryFilterWarning,
BadNumberNaryQueryArgs,
#
# Old style queries
# #################
AbstractQuery,
NaryQuery,
AndQuery,
IntersectionQuery,
OrQuery,
UnionQuery,
XorQuery,
SymDifferenceQuery,
NotQuery,
QueryMatcher,
CypherQuery,
parse_cypher_query,
is_hatchet_query,
)

# Public names exported by ``from thicket.query import *``.
# Only the new-style query classes, the query error/warning types, and
# ``is_hatchet_query`` are listed; the old-style names imported above and the
# ``is_new_style_query`` / ``is_old_style_query`` helpers defined below are not
# part of ``__all__`` (they remain importable by explicit name).
__all__ = [
"Query",
"ObjectQuery",
"StringQuery",
"parse_string_dialect",
"CompoundQuery",
"ConjunctionQuery",
"DisjunctionQuery",
"ExclusiveDisjunctionQuery",
"NegationQuery",
"InvalidQueryPath",
"InvalidQueryFilter",
"RedundantQueryFilterWarning",
"BadNumberNaryQueryArgs",
"is_hatchet_query",
]


def is_new_style_query(query_obj):
    """Tell whether ``query_obj`` is a new-style Hatchet query.

    Arguments:
        query_obj: the object to classify

    Returns:
        (bool): True if the type of ``query_obj`` is ``Query``,
            ``CompoundQuery``, or a subclass of either; False otherwise
    """
    query_type = type(query_obj)
    # A tuple classinfo is checked left to right, i.e. Query first and then
    # CompoundQuery.
    return issubclass(query_type, (Query, CompoundQuery))


def is_old_style_query(query_obj):
    """Tell whether ``query_obj`` is an old-style (legacy) Hatchet query.

    Arguments:
        query_obj: the object to classify

    Returns:
        (bool): True if the type of ``query_obj`` is ``AbstractQuery`` or a
            subclass of it; False otherwise
    """
    query_type = type(query_obj)
    return issubclass(query_type, AbstractQuery)
76 changes: 76 additions & 0 deletions thicket/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,79 @@ def test_query(rajaperf_cuda_block128_1M_cali):
)

check_query(th, hnids, query)


def test_object_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):
    """Check an object-dialect query that filters on MultiIndex columns.

    Two thickets are concatenated along the columns axis and queried for nodes
    whose "Avg time/rank" exceeds 10.0 under both the ``0`` and ``1`` column
    groups. The matched nodes must equal the expected set, and every returned
    row must satisfy both predicates.
    """
    thickets = [
        Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[0]),
        Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[1]),
    ]
    th_cj = Thicket.concat_thickets(thickets, axis="columns")

    # Column keys are (top-level column index, metric name) tuples.
    time_col_0 = (0, "Avg time/rank")
    time_col_1 = (1, "Avg time/rank")
    query = [
        ("+", {time_col_0: "> 10.0", time_col_1: "> 10.0"}),
    ]

    root = th_cj.graph.roots[0]
    lcals = root.children[3]
    expected_nodes = list(
        {
            root,  # RAJAPerf
            root.children[1],  # RAJAPerf.Apps
            root.children[2],  # RAJAPerf.Basic
            lcals,  # RAJAPerf.Lcals
            lcals.children[0],  # RAJAPerf.Lcals.Lcals_DIFF_PREDICT
            root.children[4],  # RAJAPerf.Polybench
        }
    )

    new_th = th_cj.query(query, multi_index_mode="all")
    queried_nodes = list(new_th.graph.traverse())

    # Compare by frame, in sorted order, since neither list has a fixed order.
    expected_frames = sorted(node.frame for node in expected_nodes)
    queried_frames = sorted(node.frame for node in queried_nodes)

    assert len(queried_nodes) == len(expected_nodes)
    assert all(m == q for m, q in zip(expected_frames, queried_frames))

    rows = pd.IndexSlice[queried_nodes, :]
    above_threshold_0 = new_th.dataframe.loc[rows, time_col_0] > 10.0
    above_threshold_1 = new_th.dataframe.loc[rows, time_col_1] > 10.0
    assert (above_threshold_0 & above_threshold_1).all()


def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):
    """Check a string-dialect query that filters on MultiIndex columns.

    Two thickets are concatenated along the columns axis and queried for nodes
    whose "Avg time/rank" exceeds 10.0 under both the ``0`` and ``1`` column
    groups. The matched nodes must equal the expected set, and every returned
    row must satisfy both predicates.
    """
    thickets = [
        Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[0]),
        Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[1]),
    ]
    th_cj = Thicket.concat_thickets(thickets, axis="columns")

    query = """MATCH ("+", p)
WHERE p.(0, "Avg time/rank") > 10.0 AND p.(1, "Avg time/rank") > 10.0
"""

    root = th_cj.graph.roots[0]
    lcals = root.children[3]
    expected_nodes = list(
        {
            root,  # RAJAPerf
            root.children[1],  # RAJAPerf.Apps
            root.children[2],  # RAJAPerf.Basic
            lcals,  # RAJAPerf.Lcals
            lcals.children[0],  # RAJAPerf.Lcals.Lcals_DIFF_PREDICT
            root.children[4],  # RAJAPerf.Polybench
        }
    )

    new_th = th_cj.query(query, multi_index_mode="all")
    queried_nodes = list(new_th.graph.traverse())

    # Compare by frame, in sorted order, since neither list has a fixed order.
    expected_frames = sorted(node.frame for node in expected_nodes)
    queried_frames = sorted(node.frame for node in queried_nodes)

    assert len(queried_nodes) == len(expected_nodes)
    assert all(m == q for m, q in zip(expected_frames, queried_frames))

    # Column keys are (top-level column index, metric name) tuples.
    time_col_0 = (0, "Avg time/rank")
    time_col_1 = (1, "Avg time/rank")
    rows = pd.IndexSlice[queried_nodes, :]
    above_threshold_0 = new_th.dataframe.loc[rows, time_col_0] > 10.0
    above_threshold_1 = new_th.dataframe.loc[rows, time_col_1] > 10.0
    assert (above_threshold_0 & above_threshold_1).all()
45 changes: 31 additions & 14 deletions thicket/thicket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
import numpy as np
from hatchet import GraphFrame
from hatchet.graph import Graph
from hatchet.query import AbstractQuery, QueryMatcher
from hatchet.query import QueryEngine
from thicket.query import (
Query,
ObjectQuery,
parse_string_dialect,
is_hatchet_query,
is_old_style_query,
)
import tqdm

from thicket.ensemble import Ensemble
Expand Down Expand Up @@ -82,7 +89,7 @@ def __init__(
)
else:
self.statsframe = statsframe

self.query_engine = QueryEngine()
self.performance_cols = helpers._get_perf_columns(self.dataframe)

def __eq__(self, other):
Expand Down Expand Up @@ -1042,7 +1049,7 @@ def intersection(self):
"""

# Rows that didn't exist will contain "None" in the name column.
query = QueryMatcher().match(
query = Query().match(
".", lambda row: row["name"].apply(lambda n: n is not None).all()
)
intersected_th = self.query(query)
Expand Down Expand Up @@ -1116,12 +1123,14 @@ def filter(self, filter_func):
"Invalid function: thicket.filter(), please use thicket.filter_metadata() or thicket.filter_stats()"
)

def query(self, query_obj, squash=True, update_inc_cols=True):
def query(
self, query_obj, squash=True, update_inc_cols=True, multi_index_mode="off"
):
"""Apply a Hatchet query to the Thicket object.
Arguments:
query_obj (AbstractQuery): the query, represented as by a subclass of
Hatchet's AbstractQuery
query_obj (AbstractQuery, Query, CompoundQuery, list, or str): the query, represented by a Query, a CompoundQuery, a raw object-dialect (list) or string-dialect (str) query, or (for legacy support)
a subclass of Hatchet's AbstractQuery
squash (bool): if true, run Thicket.squash before returning the result of
the query
update_inc_cols (boolean, optional): if True, update inclusive columns when
Expand All @@ -1130,20 +1139,28 @@ def query(self, query_obj, squash=True, update_inc_cols=True):
Returns:
(thicket): a new Thicket object containing the data that matches the query
"""
if isinstance(query_obj, (list, str)):
raise UnsupportedQuery(
"Object and String queries from Hatchet are not yet supported in Thicket"
local_query_obj = query_obj
if isinstance(query_obj, list):
local_query_obj = ObjectQuery(query_obj, multi_index_mode=multi_index_mode)
elif isinstance(query_obj, str):
local_query_obj = parse_string_dialect(
query_obj, multi_index_mode=multi_index_mode
)
elif not issubclass(type(query_obj), AbstractQuery):
elif not is_hatchet_query(query_obj):
raise TypeError(
"Input to 'query' must be a Hatchet query (i.e., list, str, or subclass of AbstractQuery)"
"Encountered unrecognized query type (expected Query, CompoundQuery, or AbstractQuery, got {})".format(
type(query_obj)
)
)
dframe_copy = self.dataframe.copy()
index_names = self.dataframe.index.names
dframe_copy.reset_index(inplace=True)
query = query_obj
# TODO Add a conditional here to parse Object and String queries when supported
query_matches = query.apply(self)
query = (
local_query_obj
if not is_old_style_query(query_obj)
else local_query_obj._get_new_query()
)
query_matches = self.query_engine.apply(query, self.graph, self.dataframe)
filtered_df = dframe_copy.loc[dframe_copy["node"].isin(query_matches)]
if filtered_df.shape[0] == 0:
raise EmptyQuery("The provided query would have produced an empty Thicket.")
Expand Down

0 comments on commit 68455ca

Please sign in to comment.