Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run clients with ephemeral applications on the global event loop thread #9870

Closed
wants to merge 8 commits into from
56 changes: 55 additions & 1 deletion src/prefect/_internal/concurrency/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import atexit
import concurrent.futures
import itertools
from collections import deque
import sys
import queue
import threading
from typing import List, Optional
import contextlib
from typing import List, Optional, AsyncContextManager, TypeVar

from prefect._internal.concurrency.calls import Call, Portal
from prefect._internal.concurrency.event_loop import get_running_loop
Expand All @@ -16,6 +19,9 @@
from prefect._internal.concurrency import logger


T = TypeVar("T")


class WorkerThread(Portal):
"""
A portal to a worker running on a thread.
Expand Down Expand Up @@ -135,6 +141,7 @@ def __init__(
self._submitted_count: int = 0
self._on_shutdown: List[Call] = []
self._lock = threading.Lock()
self._calls = deque()

if not daemon:
atexit.register(self.shutdown)
Expand Down Expand Up @@ -172,8 +179,47 @@ def submit(self, call: Call) -> Call:
if self._run_once:
call.future.add_done_callback(lambda _: self.shutdown())

self._calls.append(call)

return call

@contextlib.contextmanager
def wrap_context(self, context: AsyncContextManager[T]) -> T:
enter = self.submit(Call.new(context.__aenter__))
exc_info = (None, None, None)
try:
yield enter.result()
except BaseException:
exc_info = sys.exc_info()
raise
finally:
self.submit(Call.new(context.__aexit__, *exc_info)).result()

@contextlib.asynccontextmanager
async def wrap_context_async(self, context: AsyncContextManager[T]) -> T:
enter = self.submit(Call.new(context.__aenter__))
exc_info = (None, None, None)
try:
yield await enter.aresult()
except BaseException:
exc_info = sys.exc_info()
raise
finally:
exit = self.submit(Call.new(context.__aexit__, *exc_info))
await exit.aresult()

def drain(self) -> None:
"""
Wait for the event loop to finish all outstanding work.
"""
# Wait for all calls to finish
while self._calls:
call = self._calls.popleft()
try:
call.result()
except Exception:
Fixed Show fixed Hide fixed
pass

def shutdown(self) -> None:
"""
Shutdown the worker thread. Does not wait for the thread to stop.
Expand Down Expand Up @@ -275,3 +321,11 @@ def wait_for_global_loop_exit(timeout: Optional[float] = None) -> None:
raise RuntimeError("Cannot wait for the loop thread from inside itself.")

loop_thread.thread.join(timeout)


def drain_global_loop(timeout: Optional[float] = None) -> None:
"""
Wait for the global event loop to finish all outstanding work.
"""
loop_thread = get_global_loop()
loop_thread.drain()
30 changes: 29 additions & 1 deletion src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from uuid import UUID

import asyncio
import httpcore
import httpx
import pendulum
Expand Down Expand Up @@ -105,6 +106,7 @@
from prefect.tasks import Task as TaskObject

from prefect.client.base import PrefectHttpxClient, app_lifespan_context
from prefect._internal.concurrency.threads import get_global_loop


class ServerType(AutoEnum):
Expand Down Expand Up @@ -2531,7 +2533,33 @@ async def __aenter__(self):
self.logger.debug(f"Connecting to API at {self.api_url}")

# Enter the httpx client's context
await self._exit_stack.enter_async_context(self._client)
if self._ephemeral_app:
# Enter the client on the global loop thread instead
global_loop = get_global_loop()
await self._exit_stack.enter_async_context(
global_loop.wrap_context_async(self._client)
)

# Then patch the `request` method to run calls over there
request = self._client.request

def request_on_global_loop(*args, **kwargs):
if global_loop._shutdown_event.is_set():
raise RuntimeError(
"The loop this client is bound to is shutdown and it cannot be "
"used anymore."
)

return asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(
request(*args, **kwargs), global_loop._loop
)
)

self._client.request = request_on_global_loop

else:
await self._exit_stack.enter_async_context(self._client)

self._started = True

Expand Down
8 changes: 4 additions & 4 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from prefect.states import is_state
from prefect._internal.concurrency.api import create_call, from_async, from_sync
from prefect._internal.concurrency.calls import get_current_call
from prefect._internal.concurrency.threads import wait_for_global_loop_exit
from prefect._internal.concurrency.threads import (
drain_global_loop,
)
from prefect._internal.concurrency.cancellation import CancelledError, get_deadline
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas import FlowRun, OrchestrationResult, TaskRun
Expand Down Expand Up @@ -181,9 +183,7 @@ def enter_flow_run_engine_from_flow_call(

# On completion of root flows, wait for the global thread to ensure that
# any work there is complete
done_callbacks = (
[create_call(wait_for_global_loop_exit)] if not is_subflow_run else None
)
done_callbacks = [create_call(drain_global_loop)] if not is_subflow_run else None

# WARNING: You must define any context managers here to pass to our concurrency
# api instead of entering them in here in the engine entrypoint. Otherwise, async
Expand Down