-
-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: refactored Starlette and FastAPI integration
- Loading branch information
Showing
10 changed files
with
738 additions
and
645 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""FastAPI extension for Advanced Alchemy. | ||
This module provides FastAPI integration for Advanced Alchemy, including session management, | ||
database migrations, and service utilities. | ||
""" | ||
|
||
from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils | ||
from advanced_alchemy.alembic.commands import AlembicCommands | ||
from advanced_alchemy.config import AlembicAsyncConfig, AlembicSyncConfig, AsyncSessionConfig, SyncSessionConfig | ||
from advanced_alchemy.extensions.fastapi.config import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig | ||
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy | ||
from advanced_alchemy.extensions.flask.cli import get_database_migration_plugin | ||
|
||
__all__ = ( | ||
"AdvancedAlchemy", | ||
"AlembicAsyncConfig", | ||
"AlembicCommands", | ||
"AlembicSyncConfig", | ||
"AsyncSessionConfig", | ||
"EngineConfig", | ||
"SQLAlchemyAsyncConfig", | ||
"SQLAlchemySyncConfig", | ||
"SyncSessionConfig", | ||
"base", | ||
"exceptions", | ||
"filters", | ||
"get_database_migration_plugin", | ||
"mixins", | ||
"operations", | ||
"repository", | ||
"service", | ||
"types", | ||
"utils", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,35 @@ | ||
import typer | ||
from click import Context | ||
from fastapi import FastAPI | ||
from __future__ import annotations | ||
|
||
from advanced_alchemy.extensions.starlette import AdvancedAlchemy | ||
from typing import TYPE_CHECKING, Optional, cast | ||
|
||
cli = typer.Typer() | ||
import click | ||
|
||
from advanced_alchemy.cli import add_migration_commands | ||
|
||
def get_advanced_alchemy_extension(app: FastAPI) -> AdvancedAlchemy: | ||
if TYPE_CHECKING: | ||
from fastapi import FastAPI | ||
|
||
from advanced_alchemy.extensions.fastapi.extension import AdvancedAlchemy | ||
|
||
|
||
def get_database_migration_plugin(app: FastAPI) -> AdvancedAlchemy: | ||
"""Retrieve the Advanced Alchemy extension from a FastAPI application instance.""" | ||
# Replace this with the actual logic to get the extension from the app | ||
for state_key in app.state.__dict__: | ||
if isinstance(app.state.__dict__[state_key], AdvancedAlchemy): | ||
return app.state.__dict__[state_key] | ||
raise RuntimeError("Advanced Alchemy extension not found in the application.") | ||
|
||
|
||
@cli.command() | ||
def database_migration(ctx: Context) -> None: | ||
"""Manage SQLAlchemy database migrations.""" | ||
app: FastAPI = ctx.obj["app"] | ||
extension = get_advanced_alchemy_extension(app) | ||
# ... (Implement migration commands using extension.configs) | ||
# Example: | ||
for config in extension.configs: | ||
# ... (Perform migration operations using config) | ||
pass | ||
|
||
|
||
# You can add more commands as needed | ||
from advanced_alchemy.exceptions import ImproperConfigurationError | ||
|
||
extension = cast("Optional[AdvancedAlchemy]", getattr(app.state, "advanced_alchemy", None)) | ||
if extension is None: | ||
msg = "Failed to initialize database CLI. The Advanced Alchemy extension is not properly configured." | ||
raise ImproperConfigurationError(msg) | ||
return extension | ||
|
||
|
||
def register_database_commands(app: FastAPI) -> click.Group: | ||
@click.group(name="database") | ||
@click.pass_context | ||
def database_group(ctx: click.Context) -> None: | ||
"""Manage SQLAlchemy database components.""" | ||
ctx.ensure_object(dict) | ||
ctx.obj["configs"] = get_database_migration_plugin(app).config | ||
|
||
add_migration_commands(database_group) | ||
return database_group |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,9 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Literal | ||
from __future__ import annotations | ||
|
||
from litestar.types import BeforeMessageSendHookHandler | ||
from advanced_alchemy.extensions.starlette import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemySyncConfig | ||
|
||
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig | ||
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig | ||
from advanced_alchemy.extensions.litestar.plugins.init.config.common import SESSION_SCOPE_KEY | ||
from advanced_alchemy.extensions.litestar.plugins.init.config.engine import EngineConfig | ||
|
||
|
||
@dataclass | ||
class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig): | ||
"""SQLAlchemy Async config for FastAPI.""" | ||
|
||
# ... (add FastAPI-specific config options) | ||
before_send_handler: BeforeMessageSendHookHandler | None | Literal["autocommit", "autocommit_include_redirects"] = ( | ||
None | ||
) | ||
"""Handler to call before the ASGI message is sent. | ||
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and | ||
uncommitted data. | ||
""" | ||
engine_dependency_key: str = "db_engine" | ||
"""Key to use for the dependency injection of database engines.""" | ||
session_dependency_key: str = "db_session" | ||
"""Key to use for the dependency injection of database sessions.""" | ||
engine_app_state_key: str = "db_engine" | ||
"""Key under which to store the SQLAlchemy engine in the application :class:`State <litestar.datastructures.State>` | ||
instance. | ||
""" | ||
session_maker_app_state_key: str = "session_maker_class" | ||
"""Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application | ||
:class:`State <litestar.datastructures.State>` instance. | ||
""" | ||
session_scope_key: str = SESSION_SCOPE_KEY | ||
"""Key under which to store the SQLAlchemy scope in the application.""" | ||
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride] | ||
"""Configuration for the SQLAlchemy engine. | ||
The configuration options are documented in the SQLAlchemy documentation. | ||
""" | ||
set_default_exception_handler: bool = True | ||
"""Sets the default exception handler on application start.""" | ||
|
||
|
||
@dataclass | ||
class SQLAlchemySyncConfig(_SQLAlchemySyncConfig): | ||
"""SQLAlchemy Sync config for FastAPI.""" | ||
|
||
# ... (add FastAPI-specific config options) | ||
before_send_handler: BeforeMessageSendHookHandler | None | Literal["autocommit", "autocommit_include_redirects"] = ( | ||
None | ||
) | ||
"""Handler to call before the ASGI message is sent. | ||
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and | ||
uncommitted data. | ||
""" | ||
engine_dependency_key: str = "db_engine" | ||
"""Key to use for the dependency injection of database engines.""" | ||
session_dependency_key: str = "db_session" | ||
"""Key to use for the dependency injection of database sessions.""" | ||
engine_app_state_key: str = "db_engine" | ||
"""Key under which to store the SQLAlchemy engine in the application :class:`State <litestar.datastructures.State>` | ||
instance. | ||
""" | ||
session_maker_app_state_key: str = "session_maker_class" | ||
"""Key under which to store the SQLAlchemy :class:`sessionmaker <sqlalchemy.orm.sessionmaker>` in the application | ||
:class:`State <litestar.datastructures.State>` instance. | ||
""" | ||
session_scope_key: str = SESSION_SCOPE_KEY | ||
"""Key under which to store the SQLAlchemy scope in the application.""" | ||
engine_config: EngineConfig = field(default_factory=EngineConfig) # pyright: ignore[reportIncompatibleVariableOverride] | ||
"""Configuration for the SQLAlchemy engine. | ||
The configuration options are documented in the SQLAlchemy documentation. | ||
""" | ||
set_default_exception_handler: bool = True | ||
"""Sets the default exception handler on application start.""" | ||
__all__ = ( | ||
"EngineConfig", | ||
"SQLAlchemyAsyncConfig", | ||
"SQLAlchemySyncConfig", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,53 @@ | ||
from starlette.applications import Starlette | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from __future__ import annotations | ||
|
||
from advanced_alchemy.extensions.fastapi.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig | ||
from typing import TYPE_CHECKING, Sequence | ||
|
||
from fastapi_cli.cli import app as fastapi_cli_app | ||
|
||
from advanced_alchemy.extensions.fastapi.cli import register_database_commands | ||
from advanced_alchemy.extensions.starlette import AdvancedAlchemy as StarletteAdvancedAlchemy | ||
|
||
if TYPE_CHECKING: | ||
from fastapi import FastAPI | ||
|
||
from advanced_alchemy.extensions.fastapi.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig | ||
|
||
__all__ = ("AdvancedAlchemy",) | ||
|
||
|
||
class AdvancedAlchemy: | ||
"""AdvancedAlchemy integration for Starlette/FastAPI applications. | ||
def assign_cli_group(app: FastAPI) -> None: | ||
from typer.main import get_group | ||
|
||
click_app = get_group(fastapi_cli_app) | ||
click_app.add_command(register_database_commands(app)) | ||
|
||
|
||
This class manages SQLAlchemy sessions and engine lifecycle within a Starlette/FastAPI application. | ||
class AdvancedAlchemy(StarletteAdvancedAlchemy): | ||
"""AdvancedAlchemy integration for FastAPI applications. | ||
This class manages SQLAlchemy sessions and engine lifecycle within a FastAPI application. | ||
It provides middleware for handling transactions based on commit strategies. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig | list[SQLAlchemyAsyncConfig | SQLAlchemySyncConfig], | ||
app: Starlette | None = None, | ||
config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig | Sequence[SQLAlchemyAsyncConfig | SQLAlchemySyncConfig], | ||
app: FastAPI | None = None, | ||
) -> None: | ||
self.configs: list[SQLAlchemyAsyncConfig | SQLAlchemySyncConfig] = ( | ||
[config] if not isinstance(config, list) else config | ||
) | ||
self._app: Starlette | ||
self.engine_keys: list[str] = [] | ||
self.sessionmaker_keys: list[str] = [] | ||
self.session_keys: list[str] = [] | ||
if app is not None: | ||
self.init_app(app) | ||
|
||
def init_app(self, app: Starlette) -> None: | ||
"""Initializes the Starlette/FastAPI application with SQLAlchemy engine and sessionmaker. | ||
super().__init__(config, app) | ||
|
||
def init_app(self, app: FastAPI) -> None: # type: ignore[override] | ||
"""Initializes the FastAPI application with SQLAlchemy engine and sessionmaker. | ||
Sets up middleware and shutdown handlers for managing the database engine. | ||
Args: | ||
app (starlette.applications.Starlette): The Starlette/FastAPI application instance. | ||
app (fastapi.FastAPI): The FastAPI application instance. | ||
""" | ||
for config in self.configs: | ||
engine = config.get_engine() | ||
engine_key = self._make_unique_state_key(app, f"sqla_engine_{engine.name}") | ||
sessionmaker_key = self._make_unique_state_key(app, f"sqla_sessionmaker_{engine.name}") | ||
session_key = f"sqla_session_{sessionmaker_key}" | ||
|
||
self.engine_keys.append(engine_key) | ||
self.sessionmaker_keys.append(sessionmaker_key) | ||
self.session_keys.append(session_key) | ||
|
||
setattr(app.state, engine_key, engine) | ||
setattr(app.state, sessionmaker_key, config.create_session_maker()) | ||
|
||
app.add_middleware(BaseHTTPMiddleware, dispatch=self.middleware_dispatch) | ||
app.add_event_handler("shutdown", self.on_shutdown) # pyright: ignore[reportUnknownMemberType] | ||
|
||
self._app = app | ||
super().init_app(app) | ||
assign_cli_group(app) | ||
app.state.advanced_alchemy = self | ||
|
||
# ... (rest of the class methods will be adapted) | ||
async def on_shutdown(self) -> None: | ||
await super().on_shutdown() | ||
delattr(self.app.state, "advanced_alchemy") |
Oops, something went wrong.