Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow add middleware after app has started #16758

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modules/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def check_versions():

def initialize():
from modules import initialize_util
initialize_util.allow_add_middleware_after_start()
initialize_util.fix_torch_version()
initialize_util.fix_pytorch_lightning()
initialize_util.fix_asyncio_event_loop_policy()
Expand Down
40 changes: 37 additions & 3 deletions modules/initialize_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re

from modules.timer import startup_timer
from modules import patches
from functools import wraps


def gradio_server_name():
Expand Down Expand Up @@ -191,11 +193,8 @@ def configure_opts_onchange():

def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware

app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly


def configure_cors_middleware(app):
Expand All @@ -213,3 +212,38 @@ def configure_cors_middleware(app):
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)


def allow_add_middleware_after_start():
from starlette.applications import Starlette

def add_middleware_wrapper(func):
"""Patch Starlette.add_middleware to allow for middleware to be added after the app has started

Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None.
We can force add new middleware by first setting middleware_stack to None, then adding the middleware.
When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack).

If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started:
the first way is to just set middleware_stack to None and then retry
the second manually insert the middleware into the user_middleware list without calling add_middleware
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
res = None
try:
res = func(self, *args, **kwargs)
except RuntimeError as _:
try:
self.middleware_stack = None
res = func(self, *args, **kwargs)
except RuntimeError as e:
print(f'Warning: "{e}", Retrying...')
from starlette.middleware import Middleware
self.user_middleware.insert(0, Middleware(*args, **kwargs))
self.middleware_stack = None # ensure middleware_stack in the event of concurrent requests
return res

return wrapper

patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))
Loading