Skip to content

Commit

Permalink
feat(django): make gis/geo libraries optional
Browse files Browse the repository at this point in the history
  • Loading branch information
OliverHofkens committed Dec 1, 2023
1 parent 7d62b9c commit 629b49c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
3 changes: 2 additions & 1 deletion odata_query/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Identifier(_Node):
namespace: Tuple[str, ...] = field(default_factory=tuple)

def full_name(self):
return '.'.join(self.namespace + (self.name,))
return ".".join(self.namespace + (self.name,))


@dataclass(frozen=True)
Expand Down Expand Up @@ -325,6 +325,7 @@ class NamedParam(_Node):
name: Identifier
param: _Node


@dataclass(frozen=True)
class Call(_Node):
func: Identifier
Expand Down
29 changes: 25 additions & 4 deletions odata_query/django/django_q.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Type, Union
from uuid import UUID

Expand All @@ -17,8 +18,16 @@
)
from django.db.models.expressions import Expression

from django.contrib.gis.geos import GEOSGeometry
from django.contrib.gis.db.models import functions as gis_functions
try:
# Django gis requires system level libraries, which not every user needs.
from django.contrib.gis.db.models import functions as gis_functions
from django.contrib.gis.geos import GEOSGeometry

_gis_error = None
except Exception as e:
gis_functions = None
GEOSGeometry = None
_gis_error = e

from odata_query import ast, exceptions as ex, typing, utils, visitor

Expand All @@ -43,6 +52,15 @@
}


@contextmanager
def requires_gis(*args, **kwargs):
if not gis_functions:
raise ImportError(
"Cannot use geography functions because GeoDjango failed to load."
) from _gis_error
yield


class AstToDjangoQVisitor(visitor.NodeVisitor):
"""
:class:`NodeVisitor` that transforms an :term:`AST` into a Django Q
Expand Down Expand Up @@ -258,7 +276,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
def visit_Call(self, node: ast.Call) -> Union[Expression, Q]:
":meta private:"

func_name = node.func.full_name().replace('.', '__')
func_name = node.func.full_name().replace(".", "__")

try:
q_gen = getattr(self, "djangofunc_" + func_name.lower())
Expand Down Expand Up @@ -317,12 +335,15 @@ def visit_CollectionLambda(self, node: ast.CollectionLambda) -> Q:
else:
raise NotImplementedError()

@requires_gis
def djangofunc_geo__intersects(self, a, b):
return Q(**{a.name + '__' + 'intersects': GEOSGeometry(b.wkt())})
return Q(**{a.name + "__" + "intersects": GEOSGeometry(b.wkt())})

@requires_gis
def djangofunc_geo__distance(self, a, b):
return gis_functions.Distance(a.name, GEOSGeometry(b.wkt()))

@requires_gis
def djangofunc_geo__length(self, a):
return gis_functions.Length(a.name)

Expand Down
2 changes: 1 addition & 1 deletion odata_query/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def _function_call(self, func: ast.Identifier, args: List[ast._Node]):

func_name = func.full_name()

if func.namespace in ((), ('geo',)):
if func.namespace in ((), ("geo",)):
try:
n_args_exp = ODATA_FUNCTIONS[func_name]
except KeyError:
Expand Down

0 comments on commit 629b49c

Please sign in to comment.