Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple models in sqlalchemy #30

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ The format is based on `Keep a Changelog <https://keepachangelog.com/en/1.0.0/>`
and this project adheres to `Semantic Versioning <https://semver.org/spec/v2.0.0.html>`_.


Unreleased
^^^^^^^^^^

* Sqlalchemy: support for queries with multiple entities


[0.6.0] - 2022-06-03
--------------------

Expand Down
67 changes: 59 additions & 8 deletions odata_query/sqlalchemy/shorthand.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,66 @@
from typing import List
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The imports aren't organized correctly. We use isort to automate this process, as described in https://github.com/gorilla-co/odata-query/blob/master/CONTRIBUTING.rst#local-development. I did notice that our Github Workflow wasn't configured to run on Pull Requests, otherwise this would've been more clear, sorry! This is improved for future PRs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad. I did run blake but forgot to run isort. I had some issues with poetry, so I wasn't able to properly test the PR. Please do not hesitate to make any changes you think are necessary, or let me know if it is something I can do.

from typing import List, Optional, Dict

from sqlalchemy.sql.expression import ClauseElement, Select
from sqlalchemy.orm.attributes import InstrumentedAttribute

import sqlalchemy
from odata_query.grammar import ODataLexer, ODataParser # type: ignore

from odata_query.rewrite import AliasRewriter
from .sqlalchemy_clause import AstToSqlAlchemyClauseVisitor


def _get_joined_attrs(query: Select) -> List[str]:
return [str(join[0]) for join in query._setup_joins]


def apply_odata_query(query: ClauseElement, odata_query: str) -> ClauseElement:
def create_odata_ast(
odata_query: str,
name_entity_map: Optional[Dict[str, InstrumentedAttribute]] = None,
) -> ClauseElement:
"""
Shorthand for applying an OData query to a SQLAlchemy query.
Create an AST for a query

Args:
query: SQLAlchemy query to apply the OData query to.
odata_query: OData query string.
name_entity_map: Dict for rewriting, optional
Returns:
ClauseElement: The modified query
AST
"""
lexer = ODataLexer()
parser = ODataParser()
model = query.column_descriptions[0]["entity"]

ast = parser.parse(lexer.tokenize(odata_query))
if name_entity_map:
rewriter = AliasRewriter(
{alias: str(field) for alias, field in name_entity_map.items()}
)
ast = rewriter.visit(ast)
return ast


def apply_ast_query(query: ClauseElement, ast) -> ClauseElement:
"""
Apply a parsed OData query AST to a SQLAlchemy query.

Args:
query: SQLAlchemy query to apply the OData query to.
Returns:
ClauseElement: The modified query
"""
model = [
col["entity"] for col in query.column_descriptions if col["entity"] is not None
]

for col in query.column_descriptions:
if col["aliased"]:
aliased_insp = col["entity"]._aliased_insp
if aliased_insp._is_with_polymorphic:
model.extend(
[
getattr(col["entity"], sub_aliased_insp.class_.__name__)
for sub_aliased_insp in aliased_insp._with_polymorphic_entities
]
)

transformer = AstToSqlAlchemyClauseVisitor(model)
where_clause = transformer.visit(ast)

Expand All @@ -34,3 +69,19 @@ def apply_odata_query(query: ClauseElement, odata_query: str) -> ClauseElement:
query = query.join(j)

return query.filter(where_clause)


def apply_odata_query(query: ClauseElement, odata_query: str) -> ClauseElement:
"""
Shorthand for applying an OData query to a SQLAlchemy query.

Args:
query: SQLAlchemy query to apply the OData query to.
odata_query: OData query string.
Returns:
ClauseElement: The modified query
"""
ast = create_odata_ast(odata_query)
query = apply_ast_query(query, ast)

return query
44 changes: 36 additions & 8 deletions odata_query/sqlalchemy/sqlalchemy_clause.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import operator
from typing import Any, Callable, List, Optional, Type, Union
from collections.abc import Collection
from typing import Any, Callable, List, Optional, Type, Union, Dict

from sqlalchemy.inspection import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
Expand Down Expand Up @@ -37,19 +38,46 @@ class AstToSqlAlchemyClauseVisitor(visitor.NodeVisitor):
filter clause.

Args:
root_model: The root model of the query.
root_model: The root model of the query. It can be either a single
sqlalchemy ORM entity, a collection of entities, or a dict mapping
namespace to entities.
"""

def __init__(self, root_model: Type[DeclarativeMeta]):
self.root_model = root_model
def __init__(
self,
root_model: Union[
Type[DeclarativeMeta],
List[Type[DeclarativeMeta]],
Dict[str, Type[DeclarativeMeta]],
],
):
if isinstance(root_model, dict):
self._models = root_model
else:
if not isinstance(root_model, Collection):
root_model = [root_model]
self._models = {model.__name__: model for model in root_model}
self._models_set = {model.__name__ for model in root_model}
self.join_relationships: List[InstrumentedAttribute] = []

def visit_Identifier(self, node: ast.Identifier) -> ColumnClause:
":meta private:"
try:
return getattr(self.root_model, node.name)
except AttributeError:
raise ex.InvalidFieldException(node.name)

if node.namespace:
namespaces = self._models_set.intersection(node.namespace)
for namespace in namespaces:
try:
return getattr(self._models[namespace], node.name)
except AttributeError:
pass
else:
# Check all the models. Duplicate names are an issue
for model in self._models.values():
try:
return getattr(model, node.name)
except AttributeError:
pass
raise ex.InvalidFieldException(node.name)

def visit_Attribute(self, node: ast.Attribute) -> ColumnClause:
":meta private:"
Expand Down