From 68455ca3e2d7a911164f534f7d46153a107d12d8 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 26 Apr 2024 14:17:37 -0400 Subject: [PATCH] Add support for both row and column MultiIndex in dialects (#154) * Adds support for 'new-style' queries in Thicket * Adds unit test for MultiIndex column support in QL dialects * Allows "raw" string and object dialect queries to be used in Thicket.query * Adds license header and flake8 control comment to query.py * Restores the old "example_cali" pytest fixture as "simple_cali" * Updates unit tests to use the tip of develop for Hatchet * Changes new query tests to use rajaperf_seq_O3_1M_cali fixture * Adds ruff support in pyproject.toml --- .github/workflows/unit-tests.yaml | 1 + pyproject.toml | 17 +++++++ thicket/query.py | 72 +++++++++++++++++++++++++++++ thicket/tests/test_query.py | 76 +++++++++++++++++++++++++++++++ thicket/thicket.py | 45 ++++++++++++------ 5 files changed, 197 insertions(+), 14 deletions(-) create mode 100644 thicket/query.py diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index 62fd2b16..28c130e2 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -37,6 +37,7 @@ jobs: run: | python -m pip install --upgrade pip pytest pip install -r requirements.txt + python -m pip install --upgrade --force-reinstall git+https://github.com/LLNL/hatchet.git@develop python setup.py install python setup.py build_ext --inplace python -m pip list diff --git a/pyproject.toml b/pyproject.toml index d1a0840f..03e2e1e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,23 @@ version = "2024.1.0" description = "A Python-based toolkit for analyzing ensemble performance data." 
license = "MIT" +[tool.ruff] +line-length = 88 +target-version = 'py37' +include = ['\.pyi?$'] +exclude = [ + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".tox", + ".venv", + "_build", + "buck-out", + "build", + "dist", +] + [tool.black] line-length = 88 target-version = ['py37'] diff --git a/thicket/query.py b/thicket/query.py new file mode 100644 index 00000000..37f6ddf9 --- /dev/null +++ b/thicket/query.py @@ -0,0 +1,72 @@ +# Copyright 2022 Lawrence Livermore National Security, LLC and other +# Thicket Project Developers. See the top-level LICENSE file for details. +# +# SPDX-License-Identifier: MIT + +# make flake8 ignore unused names in this file. +# flake8: noqa: F401 + +from hatchet.query import ( + # New style queries + # ################# + # + # Core query types + Query, + ObjectQuery, + StringQuery, + parse_string_dialect, + # Compound queries + CompoundQuery, + ConjunctionQuery, + DisjunctionQuery, + ExclusiveDisjunctionQuery, + NegationQuery, + # Errors + InvalidQueryPath, + InvalidQueryFilter, + RedundantQueryFilterWarning, + BadNumberNaryQueryArgs, + # + # Old style queries + # ################# + AbstractQuery, + NaryQuery, + AndQuery, + IntersectionQuery, + OrQuery, + UnionQuery, + XorQuery, + SymDifferenceQuery, + NotQuery, + QueryMatcher, + CypherQuery, + parse_cypher_query, + is_hatchet_query, +) + +__all__ = [ + "Query", + "ObjectQuery", + "StringQuery", + "parse_string_dialect", + "CompoundQuery", + "ConjunctionQuery", + "DisjunctionQuery", + "ExclusiveDisjunctionQuery", + "NegationQuery", + "InvalidQueryPath", + "InvalidQueryFilter", + "RedundantQueryFilterWarning", + "BadNumberNaryQueryArgs", + "is_hatchet_query", +] + + +def is_new_style_query(query_obj): + return issubclass(type(query_obj), Query) or issubclass( + type(query_obj), CompoundQuery + ) + + +def is_old_style_query(query_obj): + return issubclass(type(query_obj), AbstractQuery) diff --git a/thicket/tests/test_query.py b/thicket/tests/test_query.py index 8e368f4e..d451f356 100644 --- 
a/thicket/tests/test_query.py +++ b/thicket/tests/test_query.py @@ -68,3 +68,79 @@ def test_query(rajaperf_cuda_block128_1M_cali): ) check_query(th, hnids, query) + + +def test_object_dialect_column_multi_index(rajaperf_seq_O3_1M_cali): + th1 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[0]) + th2 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[1]) + th_cj = Thicket.concat_thickets([th1, th2], axis="columns") + + query = [ + ("+", {(0, "Avg time/rank"): "> 10.0", (1, "Avg time/rank"): "> 10.0"}), + ] + + root = th_cj.graph.roots[0] + match = list( + set( + [ + root, # RAJAPerf + root.children[1], # RAJAPerf.Apps + root.children[2], # RAJAPerf.Basic + root.children[3], # RAJAPerf.Lcals + root.children[3].children[0], # RAJAPerf.Lcals.Lcals_DIFF_PREDICT + root.children[4], # RAJAPerf.Polybench + ] + ) + ) + + new_th = th_cj.query(query, multi_index_mode="all") + queried_nodes = list(new_th.graph.traverse()) + + match_frames = list(sorted([n.frame for n in match])) + queried_frames = list(sorted([n.frame for n in queried_nodes])) + + assert len(queried_nodes) == len(match) + assert all(m == q for m, q in zip(match_frames, queried_frames)) + idx = pd.IndexSlice + assert ( + (new_th.dataframe.loc[idx[queried_nodes, :], (0, "Avg time/rank")] > 10.0) + & (new_th.dataframe.loc[idx[queried_nodes, :], (1, "Avg time/rank")] > 10.0) + ).all() + + +def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali): + th1 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[0]) + th2 = Thicket.from_caliperreader(rajaperf_seq_O3_1M_cali[1]) + th_cj = Thicket.concat_thickets([th1, th2], axis="columns") + + query = """MATCH ("+", p) + WHERE p.(0, "Avg time/rank") > 10.0 AND p.(1, "Avg time/rank") > 10.0 + """ + + root = th_cj.graph.roots[0] + match = list( + set( + [ + root, # RAJAPerf + root.children[1], # RAJAPerf.Apps + root.children[2], # RAJAPerf.Basic + root.children[3], # RAJAPerf.Lcals + root.children[3].children[0], # RAJAPerf.Lcals.Lcals_DIFF_PREDICT + 
root.children[4], # RAJAPerf.Polybench + ] + ) + ) + + new_th = th_cj.query(query, multi_index_mode="all") + queried_nodes = list(new_th.graph.traverse()) + + match_frames = list(sorted([n.frame for n in match])) + queried_frames = list(sorted([n.frame for n in queried_nodes])) + + assert len(queried_nodes) == len(match) + assert all(m == q for m, q in zip(match_frames, queried_frames)) + idx = pd.IndexSlice + assert ( + (new_th.dataframe.loc[idx[queried_nodes, :], (0, "Avg time/rank")] > 10.0) + & (new_th.dataframe.loc[idx[queried_nodes, :], (1, "Avg time/rank")] > 10.0) + ).all() diff --git a/thicket/thicket.py b/thicket/thicket.py index 8e95673b..439e1324 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -15,7 +15,14 @@ import numpy as np from hatchet import GraphFrame from hatchet.graph import Graph -from hatchet.query import AbstractQuery, QueryMatcher +from hatchet.query import QueryEngine +from thicket.query import ( + Query, + ObjectQuery, + parse_string_dialect, + is_hatchet_query, + is_old_style_query, +) import tqdm from thicket.ensemble import Ensemble @@ -82,7 +89,7 @@ def __init__( ) else: self.statsframe = statsframe - + self.query_engine = QueryEngine() self.performance_cols = helpers._get_perf_columns(self.dataframe) def __eq__(self, other): @@ -1042,7 +1049,7 @@ def intersection(self): """ # Row that didn't exist will contain "None" in the name column. - query = QueryMatcher().match( + query = Query().match( ".", lambda row: row["name"].apply(lambda n: n is not None).all() ) intersected_th = self.query(query) @@ -1116,12 +1123,14 @@ def filter(self, filter_func): "Invalid function: thicket.filter(), please use thicket.filter_metadata() or thicket.filter_stats()" ) - def query(self, query_obj, squash=True, update_inc_cols=True): + def query( + self, query_obj, squash=True, update_inc_cols=True, multi_index_mode="off" + ): """Apply a Hatchet query to the Thicket object. 
Arguments: - query_obj (AbstractQuery): the query, represented as by a subclass of - Hatchet's AbstractQuery + query_obj (AbstractQuery, Query, or CompoundQuery): the query, represented by a Query, a CompoundQuery, or (for legacy support) + a subclass of Hatchet's AbstractQuery squash (bool): if true, run Thicket.squash before returning the result of the query update_inc_cols (boolean, optional): if True, update inclusive columns when @@ -1130,20 +1139,28 @@ def query(self, query_obj, squash=True, update_inc_cols=True): Returns: (thicket): a new Thicket object containing the data that matches the query """ - if isinstance(query_obj, (list, str)): - raise UnsupportedQuery( - "Object and String queries from Hatchet are not yet supported in Thicket" + local_query_obj = query_obj + if isinstance(query_obj, list): + local_query_obj = ObjectQuery(query_obj, multi_index_mode=multi_index_mode) + elif isinstance(query_obj, str): + local_query_obj = parse_string_dialect( + query_obj, multi_index_mode=multi_index_mode ) - elif not issubclass(type(query_obj), AbstractQuery): + elif not is_hatchet_query(query_obj): raise TypeError( - "Input to 'query' must be a Hatchet query (i.e., list, str, or subclass of AbstractQuery)" + "Encountered unrecognized query type (expected Query, CompoundQuery, or AbstractQuery, got {})".format( + type(query_obj) + ) ) dframe_copy = self.dataframe.copy() index_names = self.dataframe.index.names dframe_copy.reset_index(inplace=True) - query = query_obj - # TODO Add a conditional here to parse Object and String queries when supported - query_matches = query.apply(self) + query = ( + local_query_obj + if not is_old_style_query(query_obj) + else local_query_obj._get_new_query() + ) + query_matches = self.query_engine.apply(query, self.graph, self.dataframe) filtered_df = dframe_copy.loc[dframe_copy["node"].isin(query_matches)] if filtered_df.shape[0] == 0: raise EmptyQuery("The provided query would have produced an empty Thicket.")