Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for Hatchet's "new-style" queries to Thicket #44

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions thicket/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2022 Lawrence Livermore National Security, LLC and other
# Thicket Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: MIT

# Make flake8 ignore unused names in this file
# flake8: noqa: F401

from hatchet.query import (
Query,
ObjectQuery,
StringQuery,
parse_string_dialect,
CompoundQuery,
ConjunctionQuery,
DisjunctionQuery,
ExclusiveDisjunctionQuery,
NegationQuery,
QueryEngine,
InvalidQueryPath,
InvalidQueryFilter,
RedundantQueryFilterWarning,
BadNumberNaryQueryArgs,
)

from hatchet.query.compat import (
AbstractQuery,
NaryQuery,
AndQuery,
IntersectionQuery,
OrQuery,
UnionQuery,
XorQuery,
SymDifferenceQuery,
NotQuery,
QueryMatcher,
CypherQuery,
parse_cypher_query,
)

from hatchet.query import is_hatchet_query


def is_thicket_query(query_obj):
return is_hatchet_query(query_obj)


__all__ = [
"Query",
"ObjectQuery",
"StringQuery",
"parse_string_dialect",
"CompoundQuery",
"ConjunctionQuery",
"DisjunctionQuery",
"ExclusiveDisjunctionQuery",
"NegationQuery",
"QueryEngine",
"InvalidQueryPath",
"InvalidQueryFilter",
"RedundantQueryFilterWarning",
"BadNumberNaryQueryArgs",
"is_thicket_query",
]
55 changes: 48 additions & 7 deletions thicket/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

import re

import hatchet as ht

from thicket import Thicket
from thicket.query import Query, QueryMatcher


def check_query(th, hnids, query):
def check_query(th, hnids, query, multi_index_mode):
"""Check query function for Thicket object.

Arguments:
Expand All @@ -27,7 +26,7 @@ def check_query(th, hnids, query):
match_frames = [node.frame for node in match]
match_names = [frame["name"] for frame in match_frames]
# Match all nodes using query
filt_th = th.query(query)
filt_th = th.query(query, multi_index_mode=multi_index_mode)
filt_nodes = list(filt_th.graph.traverse())

# Get filtered nodes and profiles
Expand All @@ -43,13 +42,55 @@ def check_query(th, hnids, query):
)


def test_query(rajaperf_basecuda_xl_cali):
def test_new_style_query_base(rajaperf_basecuda_xl_cali):
# test thicket
th = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali)
# test arguments
hnids = [0, 1, 2, 3, 5, 6, 8, 9]
query = (
Query()
.match("*")
.rel(
".",
lambda row: row["name"]
.apply(lambda x: re.match(r"Algorithm.*block_128", x) is not None)
.all(),
)
)

check_query(th, hnids, query, multi_index_mode="off")


def test_new_style_query_object(rajaperf_basecuda_xl_cali):
# test thicket
th = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali)
# test arguments
hnids = [0, 1, 2, 3, 5, 6, 8, 9]
query = ["*", {"name": "Algorithm.*block_128"}]

check_query(th, hnids, query, multi_index_mode="all")


def test_new_style_query_string(rajaperf_basecuda_xl_cali):
# test thicket
th = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali)
# test arguments
hnids = [0, 1, 2, 3, 5, 6, 8, 9]
query = """
MATCH ("*")->(p)
WHERE p."name" =~ "Algorithm.*block_128"
"""

check_query(th, hnids, query, multi_index_mode="all")


def test_old_style_query(rajaperf_basecuda_xl_cali):
# test thicket
th = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali)
# test arguments
hnids = [0, 1, 2, 3, 5, 6, 8, 9]
query = (
ht.QueryMatcher()
QueryMatcher()
.match("*")
.rel(
".",
Expand All @@ -59,4 +100,4 @@ def test_query(rajaperf_basecuda_xl_cali):
)
)

check_query(th, hnids, query)
check_query(th, hnids, query, multi_index_mode="off")
40 changes: 30 additions & 10 deletions thicket/thicket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
import pandas as pd
import numpy as np
from hatchet import GraphFrame
from hatchet.query import AbstractQuery

import thicket.helpers as helpers
from thicket.query import (
is_thicket_query,
ObjectQuery,
parse_string_dialect,
QueryEngine,
AbstractQuery,
)
from .utils import verify_sorted_profile
from .utils import verify_thicket_structures

Expand Down Expand Up @@ -68,6 +74,7 @@ def __init__(
self.statsframe = statsframe

self.performance_cols = helpers._get_perf_columns(self.dataframe)
self.query_engine = QueryEngine()

def __eq__(self, other):
"""Compare two thicket objects.
Expand Down Expand Up @@ -1087,7 +1094,9 @@ def filter(self, filter_func):
"Invalid function: thicket.filter(), please use thicket.filter_metadata() or thicket.filter_stats()"
)

def query(self, query_obj, squash=True, update_inc_cols=True):
def query(
self, query_obj, squash=True, update_inc_cols=True, multi_index_mode="all"
):
"""Apply a Hatchet query to the Thicket object.

Arguments:
Expand All @@ -1097,24 +1106,35 @@ def query(self, query_obj, squash=True, update_inc_cols=True):
the query
update_inc_cols (boolean, optional): if True, update inclusive columns when
performing squash.
multi_index_mode (str, optional): select how to aggregate the results of a predicate.
Can be "all" (default; requires the predicate to be True for all rows of data for a given
node), "any" (requires the predicate to be True for one or more rows of data for a given
node), or "off" (disables the use of query language dialects)

Returns:
(Thicket): a new Thicket object containing the data that matches the query
"""
if isinstance(query_obj, (list, str)):
raise UnsupportedQuery(
"Object and String queries from Hatchet are not yet supported in Thicket"
)
elif not issubclass(type(query_obj), AbstractQuery):
if not is_thicket_query(query_obj) or not isinstance(query_obj, (list, str)):
raise TypeError(
"Input to 'query' must be a Hatchet query (i.e., list, str, or subclass of AbstractQuery)"
"Input to 'query' must be a Hatchet query (i.e., list, str, or new- or old-style query object)"
)
if multi_index_mode == "off" and (
isinstance(query_obj, list) or isinstance(query_obj, str)
):
raise UnsupportedQuery(
"'Raw' object- and string-based dialect queries cannot be used when 'multi_index_mode' is set to 'off'"
)
dframe_copy = self.dataframe.copy()
index_names = self.dataframe.index.names
dframe_copy.reset_index(inplace=True)
query = query_obj
# TODO Add a conditional here to parse Object and String queries when supported
query_matches = query.apply(self)
if isinstance(query_obj, list):
query = ObjectQuery(query_obj, multi_index_mode=multi_index_mode)
elif isinstance(query_obj, str):
query = parse_string_dialect(query_obj, multi_index_mode=multi_index_mode)
elif issubclass(type(query_obj), AbstractQuery):
query = query_obj._get_new_query()
query_matches = self.query_engine.apply(query, self.graph, self.dataframe)
filtered_df = dframe_copy.loc[dframe_copy["node"].isin(query_matches)]
if filtered_df.shape[0] == 0:
raise EmptyQuery("The provided query would have produced an empty Thicket.")
Expand Down