diff --git a/odata_query/ast.py b/odata_query/ast.py index 1189aaf..0e268d5 100644 --- a/odata_query/ast.py +++ b/odata_query/ast.py @@ -20,7 +20,7 @@ class Identifier(_Node): namespace: Tuple[str, ...] = field(default_factory=tuple) def full_name(self): - return '.'.join(self.namespace + (self.name,)) + return ".".join(self.namespace + (self.name,)) @dataclass(frozen=True) @@ -325,6 +325,7 @@ class NamedParam(_Node): name: Identifier param: _Node + @dataclass(frozen=True) class Call(_Node): func: Identifier diff --git a/odata_query/django/django_q.py b/odata_query/django/django_q.py index 4a5259e..3d0ab01 100644 --- a/odata_query/django/django_q.py +++ b/odata_query/django/django_q.py @@ -1,4 +1,5 @@ import operator +from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Type, Union from uuid import UUID @@ -17,8 +18,16 @@ ) from django.db.models.expressions import Expression -from django.contrib.gis.geos import GEOSGeometry -from django.contrib.gis.db.models import functions as gis_functions +try: + # Django gis requires system level libraries, which not every user needs. + from django.contrib.gis.db.models import functions as gis_functions + from django.contrib.gis.geos import GEOSGeometry + + _gis_error = None +except Exception as e: + gis_functions = None + GEOSGeometry = None + _gis_error = e from odata_query import ast, exceptions as ex, typing, utils, visitor @@ -43,6 +52,15 @@ } +@contextmanager +def requires_gis(*args, **kwargs): + if not gis_functions: + raise ImportError( + "Cannot use geography functions because GeoDjango failed to load." + ) from _gis_error + yield + + class AstToDjangoQVisitor(visitor.NodeVisitor): """ :class:`NodeVisitor` that transforms an :term:`AST` into a Django Q @@ -258,7 +276,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> str: def visit_Call(self, node: ast.Call) -> Union[Expression, Q]: ":meta private:" - func_name = node.func.full_name().replace('.', '__') + func_name = node.func.full_name().replace(".", "__") try: q_gen = getattr(self, "djangofunc_" + func_name.lower()) @@ -317,12 +335,15 @@ def visit_CollectionLambda(self, node: ast.CollectionLambda) -> Q: else: raise NotImplementedError() + @requires_gis def djangofunc_geo__intersects(self, a, b): - return Q(**{a.name + '__' + 'intersects': GEOSGeometry(b.wkt())}) + return Q(**{a.name + "__" + "intersects": GEOSGeometry(b.wkt())}) + @requires_gis def djangofunc_geo__distance(self, a, b): return gis_functions.Distance(a.name, GEOSGeometry(b.wkt())) + @requires_gis def djangofunc_geo__length(self, a): return gis_functions.Length(a.name) diff --git a/odata_query/grammar.py b/odata_query/grammar.py index 70e8d6d..3f88cbc 100644 --- a/odata_query/grammar.py +++ b/odata_query/grammar.py @@ -563,7 +563,7 @@ def _function_call(self, func: ast.Identifier, args: List[ast._Node]): func_name = func.full_name() - if func.namespace in ((), ('geo',)): + if func.namespace in ((), ("geo",)): try: n_args_exp = ODATA_FUNCTIONS[func_name] except KeyError: