Skip to content

Commit

Permalink
Merge pull request #746 from danielballan/fix-in-empty-list
Browse files Browse the repository at this point in the history
Handle `In(key, [])` and `NotIn(key, [])` correctly.
  • Loading branch information
padraic-shafer authored May 21, 2024
2 parents 307524d + d76a83c commit 6ca698c
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
- Propagate setting `include_data_sources` into child nodes.
- Populate attributes in member data variables and coordinates of xarray Datasets.
- Update dependencies.
- Fix behavior of queries `In` and `NotIn` when passed an empty list of values.

## v0.1.0a120 (25 April 2024)

Expand Down
8 changes: 8 additions & 0 deletions tiled/_tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def test_in(client, query_values):
]


def test_in_empty(client):
    "Searching with In(key, []) matches no entries."
    results = client.search(In("letter", []))
    assert list(results) == []


@pytest.mark.parametrize(
"query_values",
[
Expand All @@ -256,6 +260,10 @@ def test_notin(client, query_values):
)


def test_not_in_empty(client):
    "Searching with NotIn(key, []) excludes nothing: every entry is returned."
    # sorted() accepts any iterable, so no intermediate list() is needed.
    # Sorting both sides makes the comparison independent of result order.
    assert sorted(client.search(NotIn("letter", []))) == sorted(client)


@pytest.mark.parametrize(
"include_values,exclude_values",
[
Expand Down
2 changes: 2 additions & 0 deletions tiled/adapters/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,8 @@ def notin(query: Any, tree: MapAdapter) -> MapAdapter:
"""
matches = {}
if len(query.value) == 0:
return tree
for key, value, term in iter_child_metadata(query.key, tree):
if term not in query.value:
matches[key] = value
Expand Down
65 changes: 53 additions & 12 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,24 @@
import uuid
from functools import partial, reduce
from pathlib import Path
from typing import Callable, Dict
from urllib.parse import quote_plus, urlparse

import anyio
from fastapi import HTTPException
from sqlalchemy import delete, event, func, not_, or_, select, text, type_coerce, update
from sqlalchemy import (
delete,
event,
false,
func,
not_,
or_,
select,
text,
true,
type_coerce,
update,
)
from sqlalchemy.dialects.postgresql import JSONB, REGCONFIG
from sqlalchemy.engine import make_url
from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -1206,22 +1219,40 @@ def specs(query, tree):
return tree.new_variation(conditions=tree.conditions + conditions)


def in_or_not_in(query, tree, method):
dialect_name = tree.engine.url.get_dialect().name
def in_or_not_in_sqlite(query, tree, method):
    """
    Build the In/NotIn condition for SQLite and apply it to ``tree``.

    Parameters
    ----------
    query
        Query object. ``query.key`` is a dot-separated metadata path and
        ``query.value`` is the sequence of candidate values.
    tree
        Adapter to narrow; its existing ``conditions`` are preserved.
    method : {"in_", "not_in"}
        Name of the operator method invoked on the extracted metadata value.

    Returns
    -------
    The result of ``tree.new_variation`` with one extra condition appended.
    """
    keys = query.key.split(".")
    attr = orm.Node.metadata_[keys]
    if len(query.value) == 0:
        # The empty case must be handled up front: there is no first element
        # from which to infer the value type below.
        if method == "in_":
            # Nothing can possibly be "in" an empty list, so use an
            # always-false condition, ensuring that no rows are returned.
            condition = false()
        else:  # method == "not_in"
            # Everything is always "not in" an empty list.
            condition = true()
    else:
        # The type of the first value selects how the JSON metadata value is
        # extracted before the in_/not_in operator is applied to it.
        typed_value = _get_value(attr, type(query.value[0]))
        condition = getattr(typed_value, method)(query.value)
    return tree.new_variation(conditions=tree.conditions + [condition])


def in_or_not_in_postgresql(query, tree, method):
keys = query.key.split(".")
# Engage btree_gin index with @> operator
if method == "in_":
if len(query.value) == 0:
condition = false()
else:
condition = or_(
*(
orm.Node.metadata_.op("@>")(key_array_to_json(keys, item))
for item in query.value
)
)
elif method == "not_in":
elif method == "not_in":
if len(query.value) == 0:
condition = true()
else:
condition = not_(
or_(
*(
Expand All @@ -1230,13 +1261,23 @@ def in_or_not_in(query, tree, method):
)
)
)
else:
raise UnsupportedQueryType("NotIn")
else:
raise UnsupportedQueryType("In/NotIn")
return tree.new_variation(conditions=tree.conditions + [condition])


# Maps a SQL dialect name to the function that builds the In/NotIn condition
# for that dialect. Each function is called as ``f(query, tree, method)``.
_IN_OR_NOT_IN_DIALECT_DISPATCH: Dict[str, Callable] = {
    "sqlite": in_or_not_in_sqlite,
    "postgresql": in_or_not_in_postgresql,
}


def in_or_not_in(query, tree, method):
    """
    Apply an In/NotIn query to ``tree`` using the handler for its SQL dialect.

    Parameters
    ----------
    query
        Query object, passed through unchanged to the dialect handler.
    tree
        Adapter to narrow; the dialect name is read from ``tree.engine.url``.
    method : {"in_", "not_in"}
        Which membership test to apply.

    Returns
    -------
    Whatever the dialect-specific handler returns.

    Raises
    ------
    ValueError
        If ``method`` is not ``"in_"`` or ``"not_in"``.
    UnsupportedQueryType
        If no handler is registered for the engine's dialect.
    """
    METHODS = {"in_", "not_in"}
    if method not in METHODS:
        raise ValueError(f"method must be one of {METHODS}")
    dialect_name = tree.engine.url.get_dialect().name
    try:
        handler = _IN_OR_NOT_IN_DIALECT_DISPATCH[dialect_name]
    except KeyError:
        # Report an unknown dialect as an unsupported query rather than
        # leaking a bare KeyError from the dispatch table.
        raise UnsupportedQueryType("In/NotIn") from None
    return handler(query, tree, method)


def keys_filter(query, tree):
    """Narrow ``tree`` to nodes whose key is one of ``query.keys``."""
    key_is_listed = orm.Node.key.in_(query.keys)
    narrowed = tree.conditions + [key_is_listed]
    return tree.new_variation(conditions=narrowed)
Expand Down

0 comments on commit 6ca698c

Please sign in to comment.