diff --git a/.github/workflows/python-tests.yaml b/.github/workflows/python-tests.yaml index 2f388e0e72e7..7a5d863f2d0a 100644 --- a/.github/workflows/python-tests.yaml +++ b/.github/workflows/python-tests.yaml @@ -224,7 +224,7 @@ jobs: # Parallelize tests by scope to reduce expensive service fixture duplication # Do not allow the test suite to build images, as we want the prebuilt images to be tested # Do not run Kubernetes service tests, we do not have a cluster available - pytest tests -vvv --numprocesses auto --dist loadscope --disable-docker-image-builds --exclude-service kubernetes --durations=25 --cov=src/ --cov=tests/ --no-cov-on-fail --cov-report=term --cov-config=setup.cfg ${{ matrix.pytest-options }} + pytest tests -x -vvv --numprocesses auto --dist loadscope --disable-docker-image-builds --exclude-service kubernetes --durations=25 --cov=src/ --cov=tests/ --no-cov-on-fail --cov-report=term --cov-config=setup.cfg ${{ matrix.pytest-options }} - name: Check database container # Only applicable for Postgres, but we want this to run even when tests fail diff --git a/src/prefect/_internal/concurrency/api.py b/src/prefect/_internal/concurrency/api.py index 106bb28f8b5e..5436f0f10a88 100644 --- a/src/prefect/_internal/concurrency/api.py +++ b/src/prefect/_internal/concurrency/api.py @@ -168,12 +168,15 @@ async def wait_for_call_in_loop_thread( __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call]] = None, + cancel_callbacks: Optional[Iterable[Call]] = None, contexts: Optional[Iterable[ContextManager]] = None, ) -> Awaitable[T]: call = _cast_to_call(__call) waiter = AsyncWaiter(call) for callback in done_callbacks or []: waiter.add_done_callback(callback) + for callback in cancel_callbacks or []: + waiter.add_cancel_callback(callback) _base.call_soon_in_loop_thread(call, timeout=timeout) with contextlib.ExitStack() as stack: for context in contexts or []: diff --git a/src/prefect/_internal/concurrency/threads.py b/src/prefect/_internal/concurrency/threads.py index 5320d560e53c..542d2ca68482 100644 --- a/src/prefect/_internal/concurrency/threads.py +++ b/src/prefect/_internal/concurrency/threads.py @@ -5,9 +5,12 @@ import atexit import concurrent.futures import itertools +from collections import deque +import sys import queue import threading -from typing import List, Optional +import contextlib +from typing import List, Optional, AsyncContextManager, TypeVar from prefect._internal.concurrency.calls import Call, Portal from prefect._internal.concurrency.event_loop import get_running_loop @@ -16,6 +19,9 @@ from prefect._internal.concurrency import logger +T = TypeVar("T") + + class WorkerThread(Portal): """ A portal to a worker running on a thread. @@ -135,6 +141,7 @@ def __init__( self._submitted_count: int = 0 self._on_shutdown: List[Call] = [] self._lock = threading.Lock() + self._futures = deque() if not daemon: atexit.register(self.shutdown) @@ -166,7 +173,9 @@ def submit(self, call: Call) -> Call: call.set_runner(self) # Submit the call to the event loop - asyncio.run_coroutine_threadsafe(self._run_call(call), self._loop) + self._futures.append( + asyncio.run_coroutine_threadsafe(self._run_call(call), self._loop) + ) self._submitted_count += 1 if self._run_once: @@ -174,6 +183,47 @@ def submit(self, call: Call) -> Call: return call + @contextlib.contextmanager + def wrap_context(self, context: AsyncContextManager[T]) -> T: + enter = self.submit(Call.new(context.__aenter__)) + exc_info = (None, None, None) + try: + yield enter.result() + except BaseException: + exc_info = sys.exc_info() + raise + finally: + logger.debug("Exiting context %r", context) + self.submit(Call.new(context.__aexit__, *exc_info)).result() + + @contextlib.asynccontextmanager + async def wrap_context_async(self, context: AsyncContextManager[T]) -> T: + enter = self.submit(Call.new(context.__aenter__)) + exc_info = (None, None, None) + try: + yield await enter.aresult() + except BaseException: + exc_info = sys.exc_info() + raise + finally: + logger.debug("Exiting context %r", context) + exit = self.submit(Call.new(context.__aexit__, *exc_info)) + await exit.aresult() + + def drain(self) -> None: + """ + Wait for the event loop to finish all outstanding work. + """ + # Wait for all calls to finish + concurrent.futures.wait(self._futures) + + def cancel(self): + """ + Cancel all outstanding calls. + """ + for future in self._futures: + future.cancel() + def shutdown(self) -> None: """ Shutdown the worker thread. Does not wait for the thread to stop. @@ -275,3 +325,19 @@ def wait_for_global_loop_exit(timeout: Optional[float] = None) -> None: raise RuntimeError("Cannot wait for the loop thread from inside itself.") loop_thread.thread.join(timeout) + + +def drain_global_loop(timeout: Optional[float] = None) -> None: + """ + Wait for the global event loop to finish all outstanding work. + """ + loop_thread = get_global_loop() + loop_thread.drain() + + +def cancel_global_loop() -> None: + """ + Cancel all outstanding work in the global loop. + """ + loop_thread = get_global_loop() + loop_thread.cancel() diff --git a/src/prefect/_internal/concurrency/waiters.py b/src/prefect/_internal/concurrency/waiters.py index 2edad9103262..7c87726433cf 100644 --- a/src/prefect/_internal/concurrency/waiters.py +++ b/src/prefect/_internal/concurrency/waiters.py @@ -171,6 +171,7 @@ def __init__(self, call: Call[T]) -> None: self._done_callbacks = [] self._done_event = Event() self._done_waiting = False + self._cancel_callbacks = [] def submit(self, call: Call): """ @@ -227,15 +228,31 @@ async def _handle_done_callbacks(self): try: yield finally: - # Call done callbacks while self._done_callbacks: callback = self._done_callbacks.pop() if callback: + logger.debug("%r executing done callback %r", self, callback) # We shield against cancellation so we can run the callback with anyio.CancelScope(shield=True): - await self._run_done_callback(callback) + await self._run_callback(callback) - async def _run_done_callback(self, callback: Call): + @contextlib.asynccontextmanager + async def _handle_cancel_callbacks(self): + try: + yield + except asyncio.CancelledError: + while self._cancel_callbacks: + callback = self._cancel_callbacks.pop() + if callback: + logger.debug("%r executing cancel callback %r", self, callback) + # We shield against cancellation so we can run the callback + with anyio.CancelScope(shield=True): + await self._run_callback(callback) + + # Don't forget to re-raise the exception! + raise + + async def _run_callback(self, callback: Call): coro = callback.run() if coro: await coro @@ -246,6 +263,12 @@ def add_done_callback(self, callback: Call): else: self._done_callbacks.append(callback) + def add_cancel_callback(self, callback: Call): + if self._done_event.is_set(): + raise RuntimeError("Cannot add cancel callbacks to done waiters.") + else: + self._cancel_callbacks.append(callback) + def _signal_stop_waiting(self): # Only send a `None` to the queue if the waiter is still blocked reading from # the queue. Otherwise, it's possible that the event loop is stopped. @@ -263,10 +286,11 @@ async def wait(self) -> Call[T]: self._call.future.add_done_callback(lambda _: self._done_event.set()) async with self._handle_done_callbacks(): - await self._handle_waiting_callbacks() + async with self._handle_cancel_callbacks(): + await self._handle_waiting_callbacks() - # Wait for the future to be done - await self._done_event.wait() + # Wait for the future to be done + await self._done_event.wait() _WAITERS_BY_THREAD[self._owner_thread].remove(self) return self._call diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index efbc79190583..777a04203015 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -105,6 +105,7 @@ from prefect.tasks import Task as TaskObject from prefect.client.base import PrefectHttpxClient, app_lifespan_context +from prefect._internal.concurrency.threads import get_global_loop class ServerType(AutoEnum): @@ -2531,7 +2532,30 @@ async def __aenter__(self): self.logger.debug(f"Connecting to API at {self.api_url}") # Enter the httpx client's context - await self._exit_stack.enter_async_context(self._client) + if self._ephemeral_app: + # Enter the client on the global loop thread instead + global_loop = get_global_loop() + await self._exit_stack.enter_async_context( + global_loop.wrap_context_async(self._client) + ) + + # Then patch the `request` method to run calls over there + request = self._client.request + + def request_on_global_loop(*args, **kwargs): + from prefect._internal.concurrency import logger + from prefect._internal.concurrency.api import create_call + + logger.info(f"Sending request {args[0]} {args[1]}") + + return global_loop.submit( + create_call(request, *args, **kwargs) + ).aresult() + + self._client.request = request_on_global_loop + + else: + await self._exit_stack.enter_async_context(self._client) self._started = True diff --git a/src/prefect/engine.py b/src/prefect/engine.py index be53a4c58f4d..c662fcf2697d 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -40,7 +40,7 @@ from prefect.states import is_state from prefect._internal.concurrency.api import create_call, from_async, from_sync from prefect._internal.concurrency.calls import get_current_call -from prefect._internal.concurrency.threads import wait_for_global_loop_exit +from prefect._internal.concurrency.threads import drain_global_loop, cancel_global_loop from prefect._internal.concurrency.cancellation import CancelledError, get_deadline from prefect.client.orchestration import PrefectClient, get_client from prefect.client.schemas import FlowRun, OrchestrationResult, TaskRun @@ -181,9 +181,13 @@ def enter_flow_run_engine_from_flow_call( # On completion of root flows, wait for the global thread to ensure that # any work there is complete - done_callbacks = ( - [create_call(wait_for_global_loop_exit)] if not is_subflow_run else None - ) + done_callbacks = [create_call(drain_global_loop)] if not is_subflow_run else None + + # On async cancellation, ensure that cancellation is forwarded to the global loop + # otherwise we can deadlock on drain + cancel_callbacks = [create_call(begin_run.cancel)] + if not is_subflow_run: + cancel_callbacks.append(create_call(cancel_global_loop)) # WARNING: You must define any context managers here to pass to our concurrency # api instead of entering them in here in the engine entrypoint. Otherwise, async @@ -202,6 +206,9 @@ def enter_flow_run_engine_from_flow_call( begin_run, done_callbacks=done_callbacks, contexts=contexts, + # As a special case, on async cancellation cancel all remaining work in the + # global loop to prevent deadlock on drain + cancel_callbacks=cancel_callbacks, ) else: diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index 22b026b3f480..bf3aebdabdad 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -396,6 +396,7 @@ async def enter_client(context): assert startup.call_count == shutdown.call_count assert startup.call_count > 0 + @pytest.mark.skipif(os.environ.get("CI") is not None, reason="Too slow for CI") async def test_client_context_lifespan_is_robust_to_high_async_concurrency(self): startup, shutdown = MagicMock(), MagicMock() app = FastAPI(lifespan=make_lifespan(startup, shutdown)) diff --git a/tests/test_engine.py b/tests/test_engine.py index b64da2dd55e1..ceae010ea0f9 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -283,9 +283,21 @@ async def flow_resumer(): assert len(task_runs) == 5, "all tasks should finish running" +async def add_deployment_id_to_flow_run(db, deployment, flow_run_id): + from prefect.server.models.flow_runs import update_flow_run + + async with db.session_context(begin_transaction=True) as session: + await update_flow_run( + session, + flow_run_id, + FlowRun.construct(deployment_id=deployment.id), + ) + await session.commit() + + class TestNonblockingPause: async def test_paused_flows_do_not_block_execution_with_reschedule_flag( - self, prefect_client, deployment, session + self, prefect_client, deployment, db ): flow_run_id = None @@ -298,15 +310,7 @@ async def pausing_flow_without_blocking(): nonlocal flow_run_id flow_run_id = get_run_context().flow_run.id - # Add the deployment id to the flow run to allow a pause - from prefect.server.models.flow_runs import update_flow_run - - await update_flow_run( - session, - flow_run_id, - FlowRun.construct(deployment_id=deployment.id), - ) - await session.commit() + await add_deployment_id_to_flow_run(db, deployment, flow_run_id) x = await foo.submit() y = await foo.submit() @@ -327,7 +331,7 @@ async def pausing_flow_without_blocking(): assert len(task_runs) == 2, "only two tasks should have completed" async def test_paused_flows_gracefully_exit_with_reschedule_flag( - self, session, deployment + self, db, deployment ): @task async def foo(): @@ -335,15 +339,9 @@ async def foo(): @flow(task_runner=SequentialTaskRunner()) async def pausing_flow_without_blocking(): - # Add the deployment id to the flow run to allow a pause - from prefect.server.models.flow_runs import update_flow_run - - await update_flow_run( - session, - prefect.runtime.flow_run.id, - FlowRun.construct(deployment_id=deployment.id), + await add_deployment_id_to_flow_run( + db, deployment, prefect.runtime.flow_run.id ) - await session.commit() x = await foo.submit() y = await foo.submit() @@ -356,7 +354,7 @@ async def pausing_flow_without_blocking(): await pausing_flow_without_blocking() async def test_paused_flows_can_be_resumed_then_rescheduled( - self, prefect_client, deployment, session + self, prefect_client, deployment, db ): flow_run_id = None @@ -369,15 +367,7 @@ async def pausing_flow_without_blocking(): nonlocal flow_run_id flow_run_id = get_run_context().flow_run.id - # Add the deployment id to the flow run to allow a pause - from prefect.server.models.flow_runs import update_flow_run - - await update_flow_run( - session, - flow_run_id, - FlowRun.construct(deployment_id=deployment.id), - ) - await session.commit() + await add_deployment_id_to_flow_run(db, deployment, flow_run_id) x = await foo.submit() y = await foo.submit() @@ -396,9 +386,7 @@ async def pausing_flow_without_blocking(): flow_run = await prefect_client.read_flow_run(flow_run_id) assert flow_run.state.is_scheduled() - async def test_subflows_cannot_be_paused_with_reschedule_flag( - self, deployment, session - ): + async def test_subflows_cannot_be_paused_with_reschedule_flag(self, deployment, db): @task async def foo(): return 42 @@ -443,7 +431,7 @@ async def pausing_flow_without_blocking(): class TestOutOfProcessPause: async def test_flows_can_be_paused_out_of_process( - self, prefect_client, deployment, session + self, prefect_client, deployment, db ): @task async def foo(): @@ -455,15 +443,9 @@ async def foo(): @flow(task_runner=SequentialTaskRunner()) async def pausing_flow_without_blocking(): - # Add the deployment id to the flow run to allow a pause - from prefect.server.models.flow_runs import update_flow_run - - await update_flow_run( - session, - prefect.runtime.flow_run.id, - FlowRun.construct(deployment_id=deployment.id), + await add_deployment_id_to_flow_run( + db, deployment, prefect.runtime.flow_run.id ) - await session.commit() context = FlowRunContext.get() x = await foo.submit() @@ -491,22 +473,16 @@ async def pausing_flow_without_blocking(): len(paused_task_runs) == 1 ), "one task run should have exited with a paused state" - async def test_out_of_process_pauses_exit_gracefully(self, deployment, session): + async def test_out_of_process_pauses_exit_gracefully(self, deployment, db): @task async def foo(): return 42 @flow(task_runner=SequentialTaskRunner()) async def pausing_flow_without_blocking(): - # Add the deployment id to the flow run to allow a pause - from prefect.server.models.flow_runs import update_flow_run - - await update_flow_run( - session, - prefect.runtime.flow_run.id, - FlowRun.construct(deployment_id=deployment.id), + await add_deployment_id_to_flow_run( + db, deployment, prefect.runtime.flow_run.id ) - await session.commit() context = FlowRunContext.get() x = await foo.submit()