Feat: Evict requests if the client has disconnected #208

Closed

Changes from all commits (40 commits)
cdff37b
chore: Add request_evicted_status to streaming loop to cancel requests
bhimrazy Aug 21, 2024
7564479
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 21, 2024
99da82b
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 21, 2024
15fe905
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 21, 2024
e5565a8
fix failing test
bhimrazy Aug 21, 2024
36429a5
fixed: cannot access local variable 'uid'
bhimrazy Aug 21, 2024
9c08744
feat: adds test for `test_stream_client_disconnection`
bhimrazy Aug 22, 2024
f5522fa
ref: format imports using ruff
bhimrazy Aug 22, 2024
2f46532
fix lint warning for `@pytest.mark.asyncio`
bhimrazy Aug 22, 2024
7aacee6
adds a todo in the test for reminder
bhimrazy Aug 22, 2024
4327e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
c054af3
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 22, 2024
6eeb90d
adds cleanup for the dict to prevent leakage
bhimrazy Aug 22, 2024
dc041d2
chore: fix typo in test_lit_server.py
bhimrazy Aug 22, 2024
18419f1
updates the sleep time
bhimrazy Aug 22, 2024
f6763e5
updated some time
bhimrazy Aug 22, 2024
6dc6454
updated prompt len
bhimrazy Aug 22, 2024
e7b3059
chore: Remove print statement in stream_predict method
bhimrazy Aug 22, 2024
b0be9ce
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
a9d86ce
Merge branch 'feat/evict-req-on-client-disconnect' of github.com:bhim…
bhimrazy Aug 23, 2024
34453e9
chore: Add delayed prediction support in LitAPI subclasses
bhimrazy Aug 23, 2024
0069b98
updated stream test and added test for nonstream case
bhimrazy Aug 23, 2024
f3d6bd2
added logic to handle the client disconnection in predict
bhimrazy Aug 23, 2024
6029165
update sleep duration
bhimrazy Aug 23, 2024
6e95b30
Update sleep duration
bhimrazy Aug 23, 2024
f6f3e4c
update sleep time
bhimrazy Aug 23, 2024
9d47245
removed sleep
bhimrazy Aug 23, 2024
86ca3ce
check if `is_disconnected` exists
bhimrazy Aug 23, 2024
154cc6c
adds sleep
bhimrazy Aug 23, 2024
39986bf
chore: Update sleep duration
bhimrazy Aug 23, 2024
2c7633a
chore: Update sleep duration in LitServer
bhimrazy Aug 23, 2024
ccaeee9
tried another approach to check & handle disconnection
bhimrazy Aug 23, 2024
f0b19af
wrap in try catch
bhimrazy Aug 23, 2024
dcab100
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
b810a66
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
4edab2c
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
3d9a9a7
Merge branch 'main' into feat/evict-req-on-client-disconnect
aniketmaurya Aug 24, 2024
5c0d7fc
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 25, 2024
919b304
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 26, 2024
6ffe51c
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 26, 2024
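In the non-streaming `/predict` path, the change races the wait for the inference worker's response against a poll of `request.is_disconnected()`, and evicts the request if the client goes away first. The sketch below shows that pattern in isolation; it is a simplified illustration assuming a FastAPI/Starlette `Request`, and names such as `handle_request`, `response_event`, and `evicted` are illustrative rather than litserve's actual identifiers.

```python
import asyncio

from fastapi import HTTPException, Request


async def handle_request(request: Request, response_event: asyncio.Event, uid: str, evicted: dict):
    """Wait for the inference result, but give up if the client disconnects first."""

    async def wait_for_response():
        await response_event.wait()

    async def check_disconnection():
        # Starlette exposes is_disconnected() on the request; poll it periodically.
        while not await request.is_disconnected():
            await asyncio.sleep(1)

    response_task = asyncio.create_task(wait_for_response())
    disconnection_task = asyncio.create_task(check_disconnection())

    done, _ = await asyncio.wait({response_task, disconnection_task}, return_when=asyncio.FIRST_COMPLETED)

    if response_task in done:
        disconnection_task.cancel()
        return  # the caller can now pop the buffered response for `uid`

    # The client dropped the connection before the worker finished.
    response_task.cancel()
    evicted[uid] = True  # signal the inference worker to stop work for this uid
    raise HTTPException(status_code=499, detail="Client closed request")
```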
107 changes: 84 additions & 23 deletions src/litserve/server.py
@@ -110,7 +110,13 @@ def collate_requests(
return payloads, timed_out_uids


def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_single_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -146,6 +152,8 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re
lit_api.encode_response,
y,
)
# TODO: Cancel the task if the client disconnects

response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
except Exception as e:
logger.exception(
@@ -217,7 +225,13 @@ def run_batched_loop(
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -256,6 +270,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
y_gen,
)
for y_enc in y_enc_gen:
if request_evicted_status.get(uid):
request_evicted_status.pop(uid)
break
y_enc = lit_api.format_encoded_response(y_enc)
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
@@ -338,6 +355,7 @@ def inference_worker(
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
max_batch_size: int,
batch_timeout: float,
stream: bool,
@@ -357,7 +375,7 @@
if max_batch_size > 1:
run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status)
return

if max_batch_size > 1:
@@ -368,6 +386,7 @@
lit_spec,
request_queue,
response_queues,
request_evicted_status,
)


@@ -397,7 +416,7 @@ async def response_queue_to_buffer(
await asyncio.sleep(0.0001)
continue
q, event = buffer[uid]
q.append(payload)
q.append((uid, payload))
event.set()

else:
@@ -498,6 +517,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
manager = mp.Manager()
self.workers_setup_status = manager.dict()
self.request_queue = manager.Queue()
self.request_evicted_status = manager.dict()

self.response_queues = []
for _ in range(num_uvicorn_servers):
@@ -531,6 +551,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
worker_id,
self.request_queue,
self.response_queues,
self.request_evicted_status,
self.max_batch_size,
self.batch_timeout,
self.stream,
@@ -568,26 +589,37 @@ def device_identifiers(self, accelerator, device):
return [f"{accelerator}:{device}"]

async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
uid = None
while True:
await data_available.wait()
while len(q) > 0:
data, status = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
try:
await data_available.wait()
while len(q) > 0:
uid, (data, status) = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
if send_status:
yield data, status
return
if send_status:
yield data, status
return
if send_status:
yield data, status
else:
yield data
data_available.clear()
else:
yield data
data_available.clear()
except asyncio.CancelledError:
if uid is not None:
self.request_evicted_status[uid] = True
logger.error("Request evicted for the uid=%s", uid)
break
except Exception as e:
# Handle other exceptions that might occur
logger.error(f"Exception occurred during streaming: {e}")
break

def setup_server(self):
workers_ready = False
@@ -625,8 +657,37 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks)

self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload))

await event.wait()
response, status = self.response_buffer.pop(uid)
async def wait_for_response():
await event.wait()
return self.response_buffer.pop(uid)

async def check_disconnection():
while True:
if hasattr(request, "is_disconnected") and await request.is_disconnected():
return True
await asyncio.sleep(1) # Check every second

response_task = asyncio.create_task(wait_for_response())
disconnection_task = asyncio.create_task(check_disconnection())

try:
# Use asyncio.wait to handle both response and disconnection checks
done, pending = await asyncio.wait(
[response_task, disconnection_task], return_when=asyncio.FIRST_COMPLETED
)
if response_task in done:
response, status = await response_task
disconnection_task.cancel()
else:
response_task.cancel()
logger.error(f"Client disconnected for the request uid={uid}")
self.request_evicted_status[uid] = True
raise HTTPException(status_code=499, detail="Client closed request")
except asyncio.CancelledError:
response_task.cancel()
disconnection_task.cancel()
logger.error(f"Client disconnected for the request uid={uid}")
raise HTTPException(status_code=499, detail="Client closed request")

if status == LitAPIStatus.ERROR:
load_and_raise(response)
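On the streaming path, eviction happens in two steps: the generator feeding the streaming response is cancelled by the ASGI server when the client disconnects, and catching that cancellation is where the request's uid gets flagged; the inference worker then checks the shared flag before emitting each chunk and stops generating for that uid. Below is a condensed sketch of both halves, assuming `request_evicted_status` is a process-shared mapping such as `multiprocessing.Manager().dict()`; plain strings stand in for the `LitAPIStatus` values.

```python
import asyncio
import logging
from collections import deque

logger = logging.getLogger(__name__)


async def data_streamer(q: deque, data_available: asyncio.Event, request_evicted_status: dict):
    """Yield streamed chunks; on client disconnect, flag the uid for eviction."""
    uid = None
    try:
        while True:
            await data_available.wait()
            while len(q) > 0:
                uid, (data, status) = q.popleft()
                if status == "FINISH_STREAMING":
                    return
                yield data
            data_available.clear()
    except asyncio.CancelledError:
        # The ASGI server cancels this generator when the client goes away.
        if uid is not None:
            request_evicted_status[uid] = True
            logger.error("Request evicted for the uid=%s", uid)


def stream_worker_chunks(uid, y_enc_gen, request_evicted_status: dict, response_queue):
    """Inference-worker side: stop producing output once the uid is flagged."""
    for y_enc in y_enc_gen:
        if request_evicted_status.get(uid):
            request_evicted_status.pop(uid, None)  # clean up to avoid leaking entries
            break
        response_queue.put((uid, (y_enc, "OK")))
    response_queue.put((uid, ("", "FINISH_STREAMING")))
```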
24 changes: 24 additions & 0 deletions tests/conftest.py
@@ -37,6 +37,12 @@ def encode_response(self, output) -> Response:
return {"output": output}


class SimpleDelayedLitAPI(SimpleLitAPI):
def predict(self, x):
time.sleep(0.5)
return self.model(x)


class SimpleStreamAPI(LitAPI):
def setup(self, device) -> None:
self.sentence = "LitServe is streaming output"
@@ -55,6 +61,14 @@ def encode_response(self, output: Generator) -> Generator:
yield out.lower()


class SimpleDelayedStreamAPI(SimpleStreamAPI):
def encode_response(self, output: Generator) -> Generator:
delay = 0.2
for out in output:
time.sleep(delay)
yield out.lower()


class SimpleBatchedStreamAPI(LitAPI):
def setup(self, device) -> None:
self.sentence = "LitServe is streaming output"
@@ -88,11 +102,21 @@ def simple_litapi():
return SimpleLitAPI()


@pytest.fixture()
def simple_delayed_litapi():
return SimpleDelayedLitAPI()


@pytest.fixture()
def simple_stream_api():
return SimpleStreamAPI()


@pytest.fixture()
def simple_delayed_stream_api():
return SimpleDelayedStreamAPI()


@pytest.fixture()
def simple_batched_stream_api():
return SimpleBatchedStreamAPI()
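The delayed API variants above exist so that a prediction is still in flight when a test cancels the client's request. A similar setup can be reproduced by hand; the snippet below is illustrative only and assumes the conftest classes are importable from the repository root.

```python
# Hypothetical manual experiment, not part of the PR: serve the delayed API so a
# request takes long enough to interrupt by hand (e.g. Ctrl+C on a curl call).
import litserve as ls

from tests.conftest import SimpleDelayedLitAPI

if __name__ == "__main__":
    server = ls.LitServer(SimpleDelayedLitAPI(), timeout=10)
    server.run(port=8000)
```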
76 changes: 58 additions & 18 deletions tests/test_lit_server.py
@@ -13,32 +13,33 @@
# limitations under the License.
import asyncio
import inspect
import logging
import pickle
import re
from asgi_lifespan import LifespanManager
from litserve import LitAPI
from fastapi import Request, Response, HTTPException
import time
import torch
import torch.nn as nn
from queue import Queue
from httpx import AsyncClient
from litserve.utils import wrap_litserve_start
from unittest.mock import MagicMock, patch

from unittest.mock import patch, MagicMock
import pytest
import torch
import torch.nn as nn
from asgi_lifespan import LifespanManager
from fastapi import HTTPException, Request, Response
from fastapi.testclient import TestClient
from httpx import AsyncClient

import litserve as ls
from litserve import LitAPI
from litserve.connector import _Connector
from litserve.server import (
LitAPIStatus,
LitServer,
inference_worker,
run_batched_streaming_loop,
run_single_loop,
run_streaming_loop,
LitAPIStatus,
run_batched_streaming_loop,
)
from litserve.server import LitServer
import litserve as ls
from fastapi.testclient import TestClient
from litserve.utils import wrap_litserve_start


def test_index(sync_testclient):
@@ -66,10 +67,10 @@ def test_device_identifiers(lifespan_mock, simple_litapi):
@patch("litserve.server.run_batched_loop")
@patch("litserve.server.run_single_loop")
def test_inference_worker(mock_single_loop, mock_batched_loop):
inference_worker(*[MagicMock()] * 6, max_batch_size=2, batch_timeout=0, stream=False)
inference_worker(*[MagicMock()] * 7, max_batch_size=2, batch_timeout=0, stream=False)
mock_batched_loop.assert_called_once()

inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False)
inference_worker(*[MagicMock()] * 7, max_batch_size=1, batch_timeout=0, stream=False)
mock_single_loop.assert_called_once()


@@ -94,9 +95,9 @@ def test_single_loop(loop_args):
lit_api_mock, requests_queue = loop_args
lit_api_mock.unbatch.side_effect = None
response_queues = [FakeResponseQueue()]

request_evicted_status = {}
with pytest.raises(StopIteration, match="exit loop"):
run_single_loop(lit_api_mock, None, requests_queue, response_queues)
run_single_loop(lit_api_mock, None, requests_queue, response_queues, request_evicted_status)


@pytest.mark.asyncio()
@@ -120,6 +121,44 @@ async def test_stream(simple_stream_api):
), "Server returns input prompt and generated output which didn't match."


@pytest.mark.asyncio()
async def test_client_disconnection(simple_delayed_litapi, caplog):
server = LitServer(simple_delayed_litapi, timeout=10)

with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
task = asyncio.create_task(ac.post("/predict", json={"input": 1.0}, timeout=10))
await asyncio.sleep(0.2)
task.cancel()
await asyncio.sleep(1)
assert "Client disconnected for the request uid" in caplog.text
# TODO: also check if the task actually stopped in the server

caplog.clear()
task = asyncio.create_task(ac.post("/predict", json={"input": 1.0}, timeout=10))
await task
assert "Client disconnected for the request uid" not in caplog.text


@pytest.mark.asyncio()
async def test_stream_client_disconnection(simple_delayed_stream_api, caplog):
server = LitServer(simple_delayed_stream_api, stream=True, timeout=10)

with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 5}, timeout=10))
await asyncio.sleep(2)
task.cancel() # simulate client disconnection
await asyncio.sleep(1) # wait for the task to stop
assert "Request evicted for the uid=" in caplog.text
# TODO: also check if the task actually stopped in the server

caplog.clear()
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10))
await task
assert "Request evicted for the uid=" not in caplog.text


@pytest.mark.asyncio()
async def test_batched_stream_server(simple_batched_stream_api):
server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
@@ -175,11 +214,12 @@ def fake_encode(output):
fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)

requests_queue = Queue()
request_evicted_status = {}
requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))
response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]

with pytest.raises(StopIteration, match="exit loop"):
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues)
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues, request_evicted_status)

fake_stream_api.predict.assert_called_once_with("Hello")
fake_stream_api.encode_response.assert_called_once()
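From the client's side, the behaviour these tests simulate can also be reproduced by dropping a streaming connection early. The snippet below is a hedged illustration against a locally running streaming LitServer; the endpoint and payload mirror the tests, but the snippet itself is not part of the PR.

```python
import httpx


def stream_then_disconnect(url: str = "http://127.0.0.1:8000/predict") -> None:
    """Read a couple of streamed chunks, then close the connection early."""
    with httpx.stream("POST", url, json={"prompt": "Hey, How are you doing?"}, timeout=10) as response:
        for i, chunk in enumerate(response.iter_text()):
            print(chunk)
            if i >= 1:
                break  # leaving the block closes the connection; the server evicts the request
```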