Skip to content

Commit

Permalink
Adds the ability to match leaf nodes in the Query Language (#44)
Browse files Browse the repository at this point in the history
* Adds support for finding leaves to object and string dialects

* Adds warnings when using negative values for comparing to depth and node_id

* Adds LeafCond and NotLeafCond to CypherQuery's _is_unary_cond function

* Adds a unit test for leaf querying
  • Loading branch information
ilumsden authored Sep 12, 2022
1 parent aaead67 commit a6178bd
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 2 deletions.
221 changes: 219 additions & 2 deletions hatchet/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from numbers import Real
import re
import sys
import warnings
import pandas as pd
from pandas import DataFrame
from pandas.core.indexes.multi import MultiIndex
Expand Down Expand Up @@ -178,6 +179,9 @@ def filter_single_series(df_row, key, single_value):
) and single_value.lower().startswith(compops):
return eval("{} {}".format(node._depth, single_value))
if isinstance(single_value, Real):
# If the value for "depth" is -1, check if the node is a leaf
if single_value == -1:
return len(node.children) == 0
return node._depth == single_value
raise InvalidQueryFilter(
"Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format(
Expand Down Expand Up @@ -675,9 +679,11 @@ def _apply_impl(self, gf, node, visited, matches):
OrCond: 'OR' subcond=UnaryCond;
UnaryCond: NotCond | SingleCond;
NotCond: 'NOT' subcond=SingleCond;
SingleCond: StringCond | NumberCond | NoneCond | NotNoneCond;
SingleCond: StringCond | NumberCond | NoneCond | NotNoneCond | LeafCond | NotLeafCond;
NoneCond: name=ID '.' prop=STRING 'IS NONE';
NotNoneCond: name=ID '.' prop=STRING 'IS NOT NONE';
LeafCond: name=ID 'IS LEAF';
NotLeafCond: name=ID 'IS NOT LEAF';
StringCond: StringEq | StringStartsWith | StringEndsWith | StringContains | StringMatch;
StringEq: name=ID '.' prop=STRING '=' val=STRING;
StringStartsWith: name=ID '.' prop=STRING 'STARTS WITH' val=STRING;
Expand Down Expand Up @@ -814,7 +820,7 @@ def _is_unary_cond(self, obj):
cname(obj) == "NotCond"
or self._is_str_cond(obj)
or self._is_num_cond(obj)
or cname(obj) in ["NoneCond", "NotNoneCond"]
or cname(obj) in ["NoneCond", "NotNoneCond", "LeafCond", "NotLeafCond"]
):
return True
return False
Expand Down Expand Up @@ -860,6 +866,10 @@ def _parse_single_cond(self, obj):
return self._parse_none(obj)
if cname(obj) == "NotNoneCond":
return self._parse_not_none(obj)
if cname(obj) == "LeafCond":
return self._parse_leaf(obj)
if cname(obj) == "NotLeafCond":
return self._parse_not_leaf(obj)
raise RuntimeError("Bad Single Condition")

def _parse_none(self, obj):
Expand Down Expand Up @@ -906,6 +916,22 @@ def _parse_not_none(self, obj):
None,
]

def _parse_leaf(self, obj):
return [
None,
obj.name,
"len(df_row.name.children) == 0",
None,
]

def _parse_not_leaf(self, obj):
return [
None,
obj.name,
"len(df_row.name.children) > 0",
None,
]

def _is_str_cond(self, obj):
if cname(obj) in [
"StringEq",
Expand Down Expand Up @@ -1008,13 +1034,56 @@ def _parse_num(self, obj):

def _parse_num_eq(self, obj):
if obj.prop == "depth":
if obj.val == -1:
return [
None,
obj.name,
"len(df_row.name.children) == 0",
None,
]
elif obj.val < 0:
warnings.warn(
"""
The 'depth' property of a Node is strictly non-negative.
This condition will always be false.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"False",
"isinstance(df_row.name._depth, Real)",
]
return [
None,
obj.name,
"df_row.name._depth == {}".format(obj.val),
"isinstance(df_row.name._depth, Real)",
]
if obj.prop == "node_id":
if obj.val < 0:
warnings.warn(
"""
The 'node_id' property of a Node is strictly non-negative.
This condition will always be false.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"False",
"isinstance(df_row.name._hatchet_nid, Real)",
]
return [
None,
obj.name,
Expand All @@ -1030,13 +1099,49 @@ def _parse_num_eq(self, obj):

def _parse_num_lt(self, obj):
if obj.prop == "depth":
if obj.val < 0:
warnings.warn(
"""
The 'depth' property of a Node is strictly non-negative.
This condition will always be false.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"False",
"isinstance(df_row.name._depth, Real)",
]
return [
None,
obj.name,
"df_row.name._depth < {}".format(obj.val),
"isinstance(df_row.name._depth, Real)",
]
if obj.prop == "node_id":
if obj.val < 0:
warnings.warn(
"""
The 'node_id' property of a Node is strictly non-negative.
This condition will always be false.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"False",
"isinstance(df_row.name._hatchet_nid, Real)",
]
return [
None,
obj.name,
Expand All @@ -1052,13 +1157,49 @@ def _parse_num_lt(self, obj):

def _parse_num_gt(self, obj):
if obj.prop == "depth":
if obj.val < 0:
warnings.warn(
"""
The 'depth' property of a Node is strictly non-negative.
This condition will always be true.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"True",
"isinstance(df_row.name._depth, Real)",
]
return [
None,
obj.name,
"df_row.name._depth > {}".format(obj.val),
"isinstance(df_row.name._depth, Real)",
]
if obj.prop == "node_id":
if obj.val < 0:
warnings.warn(
"""
The 'node_id' property of a Node is strictly non-negative.
This condition will always be true.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"True",
"isinstance(df_row.name._hatchet_nid, Real)",
]
return [
None,
obj.name,
Expand All @@ -1074,13 +1215,49 @@ def _parse_num_gt(self, obj):

def _parse_num_lte(self, obj):
if obj.prop == "depth":
if obj.val < 0:
warnings.warn(
"""
The 'depth' property of a Node is strictly non-negative.
This condition will always be false.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"False",
"isinstance(df_row.name._depth, Real)",
]
return [
None,
obj.name,
"df_row.name._depth <= {}".format(obj.val),
"isinstance(df_row.name._depth, Real)",
]
if obj.prop == "node_id":
if obj.val < 0:
warnings.warn(
"""
The 'node_id' property of a Node is strictly non-negative.
This condition will always be false.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"False",
"isinstance(df_row.name._hatchet_nid, Real)",
]
return [
None,
obj.name,
Expand All @@ -1096,13 +1273,49 @@ def _parse_num_lte(self, obj):

def _parse_num_gte(self, obj):
if obj.prop == "depth":
if obj.val < 0:
warnings.warn(
"""
The 'depth' property of a Node is strictly non-negative.
This condition will always be true.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"True",
"isinstance(df_row.name._depth, Real)",
]
return [
None,
obj.name,
"df_row.name._depth >= {}".format(obj.val),
"isinstance(df_row.name._depth, Real)",
]
if obj.prop == "node_id":
if obj.val < 0:
warnings.warn(
"""
The 'node_id' property of a Node is strictly non-negative.
This condition will always be true.
The statement that triggered this warning is:
{}
""".format(
obj
),
RedundantQueryFilterWarning,
)
return [
None,
obj.name,
"True",
"isinstance(df_row.name._hatchet_nid, Real)",
]
return [
None,
obj.name,
Expand Down Expand Up @@ -1509,5 +1722,9 @@ class InvalidQueryFilter(Exception):
"""Raised when a query filter does not have a valid syntax"""


class RedundantQueryFilterWarning(Warning):
"""Warned when a query filter does nothing or is redundant"""


class BadNumberNaryQueryArgs(Exception):
"""Raised when a query filter does not have a valid syntax"""
36 changes: 36 additions & 0 deletions hatchet/tests/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,3 +1211,39 @@ def test_cypher_xor_compound_query(mock_graph_literal):
]
assert sorted(compound_query1.apply(gf)) == sorted(matches)
assert sorted(compound_query2.apply(gf)) == sorted(matches)


def test_leaf_query(small_mock2):
gf = GraphFrame.from_literal(small_mock2)
roots = gf.graph.roots
matches = [
roots[0].children[0].children[0],
roots[0].children[0].children[1],
roots[0].children[1].children[0],
roots[0].children[1].children[1],
]
nodes = set(gf.graph.traverse())
nonleaves = list(nodes - set(matches))
obj_query = QueryMatcher([{"depth": -1}])
str_query_numeric = parse_cypher_query(
u"""
MATCH (p)
WHERE p."depth" = -1
"""
)
str_query_is_leaf = parse_cypher_query(
u"""
MATCH (p)
WHERE p IS LEAF
"""
)
str_query_is_not_leaf = parse_cypher_query(
u"""
MATCH (p)
WHERE p IS NOT LEAF
"""
)
assert sorted(obj_query.apply(gf)) == sorted(matches)
assert sorted(str_query_numeric.apply(gf)) == sorted(matches)
assert sorted(str_query_is_leaf.apply(gf)) == sorted(matches)
assert sorted(str_query_is_not_leaf.apply(gf)) == sorted(nonleaves)

0 comments on commit a6178bd

Please sign in to comment.