Support vllm for openai api, refactor openai non-streaming api (#86)
* add vllm support for openai mode

* fix streaming response

* refactor openai call function

* fix tokens' length

* fix tokens' length for vllm

* run ci

* modify

* remove return_shape param, add data class

* modify

* modify

* address comments

* modify

* minimal modify after testing again
KepingYan authored Feb 5, 2024
1 parent 8286f68 commit a3ca834
Showing 12 changed files with 245 additions and 148 deletions.
34 changes: 19 additions & 15 deletions examples/inference/api_server_openai/query_http_requests.py
@@ -64,20 +64,24 @@
}

proxies = {"http": None, "https": None}
response = s.post(url, json=body, proxies=proxies) # type: ignore
response = s.post(url, json=body, proxies=proxies, stream=args.streaming_response) # type: ignore
for chunk in response.iter_lines(decode_unicode=True):
if chunk is not None:
if args.streaming_response:
# Get data from reponse chunk
chunk_data = chunk.split("data: ")[1]
if chunk_data != "[DONE]":
# Get message choices from data
choices = json.loads(chunk_data)["choices"]
# Pick content from first choice
content = choices[0]["delta"].get("content", "")
try:
if chunk is not None:
if args.streaming_response:
# Get data from reponse chunk
chunk_data = chunk.split("data: ")[1]
if chunk_data != "[DONE]":
# Get message choices from data
choices = json.loads(chunk_data)["choices"]
# Pick content from first choice
content = choices[0]["delta"].get("content", "")
print(content, end="", flush=True)
else:
choices = json.loads(chunk)["choices"]
content = choices[0]["message"].get("content", "")
print(content, end="", flush=True)
else:
choices = json.loads(chunk)["choices"]
content = choices[0]["message"].get("content", "")
print(content, end="", flush=True)
print("")
except Exception as e:
print("chunk content: ", chunk)
raise e
print()
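For context, here is a minimal, self-contained sketch of a client consuming this OpenAI-compatible endpoint after the change. The server URL, model id, and request payload are illustrative assumptions, not taken from this diff; only the SSE parsing mirrors the code above.

```python
# Minimal client sketch (assumed URL, model id, and payload shape).
import json
import requests

url = "http://localhost:8000/v1/chat/completions"  # assumed endpoint
body = {
    "model": "gpt2",  # illustrative model id
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,   # plays the role of args.streaming_response above
}

with requests.Session() as s:
    response = s.post(url, json=body, stream=body["stream"])
    response.raise_for_status()
    for chunk in response.iter_lines(decode_unicode=True):
        if not chunk:
            continue
        if body["stream"]:
            # Server-sent events: each line is "data: {...}" or "data: [DONE]".
            data = chunk.split("data: ", 1)[1]
            if data == "[DONE]":
                break
            delta = json.loads(data)["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)
        else:
            # Non-streaming: the whole completion arrives as one JSON body.
            message = json.loads(chunk)["choices"][0]["message"]
            print(message.get("content", ""))
```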
17 changes: 7 additions & 10 deletions examples/inference/api_server_simple/query_single.py
@@ -76,13 +76,10 @@
stream=args.streaming_response,
)

try:
outputs.raise_for_status()
if args.streaming_response:
for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
print(output, end="", flush=True)
print()
else:
print(outputs.text, flush=True)
except Exception as e:
print(e)
outputs.raise_for_status()
if args.streaming_response:
for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
print(output, end="", flush=True)
print()
else:
print(outputs.text, flush=True)
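With the blanket try/except removed, HTTP errors from raise_for_status() now propagate to the caller. A hedged sketch of how a caller could still handle them explicitly; the endpoint and JSON payload are placeholders, not taken from this diff.

```python
# Sketch: explicit error handling around the simple-API call.
import requests

try:
    outputs = requests.post(
        "http://localhost:8000/my_model",  # assumed simple-API endpoint
        json={"text": "Hello", "config": {}, "stream": False},  # placeholder payload
        timeout=60,
    )
    outputs.raise_for_status()  # raises requests.HTTPError on 4xx/5xx
    print(outputs.text, flush=True)
except requests.HTTPError as e:
    # Non-2xx status: surface the body for debugging instead of swallowing it.
    print(f"Request failed: {e}; body: {e.response.text}")
```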
17 changes: 4 additions & 13 deletions inference/api_openai_backend/query_client.py
@@ -34,24 +34,15 @@

from typing import Dict
from fastapi import HTTPException
from .openai_protocol import ModelCard, Prompt, ModelResponse
from .openai_protocol import ModelCard, Prompt
from .request_handler import handle_request


class RouterQueryClient:
def __init__(self, serve_deployments):
self.serve_deployments = serve_deployments

async def query(self, model: str, prompt: Prompt, request_id: str):
response_stream = self.stream(
model,
prompt,
request_id,
)
responses = [resp async for resp in response_stream]
return ModelResponse.merge_stream(*responses)

async def stream(self, model: str, prompt: Prompt, request_id: str):
async def query(self, model: str, prompt: Prompt, request_id: str, streaming_reponse: bool):
if model in self.serve_deployments:
deploy_handle = self.serve_deployments[model]
else:
@@ -75,8 +66,8 @@ async def stream(self, model: str, prompt: Prompt, request_id: str):
prompt=prompt,
request_id=request_id,
async_iterator=deploy_handle.options(stream=True)
.stream_response.options(stream=True, use_new_handle_api=True)
.remote(prompt_content, gen_config),
.openai_call.options(stream=True, use_new_handle_api=True)
.remote(prompt_content, gen_config, streaming_response=streaming_reponse),
):
yield x

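With the old stream() folded into query(), callers of RouterQueryClient now consume a single async generator in both modes. A minimal sketch of the calling pattern; the wrapper function names are illustrative, and the streaming flag is passed positionally as in the router code.

```python
# Sketch: consuming the unified RouterQueryClient.query() generator.
async def fetch_completion(query_client, model, prompt, request_id):
    """Non-streaming: the deployment is expected to yield one complete response."""
    async for resp in query_client.query(model, prompt, request_id, False):
        if resp.error:
            raise RuntimeError(resp.error.message)
        return resp  # first (and only) item is the whole answer

async def stream_completion(query_client, model, prompt, request_id):
    """Streaming: forward every yielded chunk to the caller."""
    async for resp in query_client.query(model, prompt, request_id, True):
        yield resp
```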
4 changes: 2 additions & 2 deletions inference/api_openai_backend/request_handler.py
@@ -35,7 +35,7 @@
import asyncio
import traceback
from typing import AsyncIterator, List
from fastapi import status, HTTPException
from fastapi import status, HTTPException, Request
from starlette.responses import JSONResponse
from pydantic import ValidationError as PydanticValidationError
from logger import get_logger
@@ -56,7 +56,7 @@ def __init__(
self.type = type


def openai_exception_handler(exc: OpenAIHTTPException):
def openai_exception_handler(r: Request, exc: OpenAIHTTPException):
assert isinstance(exc, OpenAIHTTPException), f"Unable to handle invalid exception {type(exc)}"
if exc.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR:
message = "Internal Server Error"
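The extra Request parameter matches FastAPI's exception-handler signature, which calls registered handlers as (request, exc). A short sketch of how such a handler is typically wired up; the import path is assumed from the file location in this diff.

```python
# Sketch: registering the handler so FastAPI can invoke it as (request, exc).
from fastapi import FastAPI
# Assumed import path, mirroring inference/api_openai_backend/request_handler.py
from inference.api_openai_backend.request_handler import (
    OpenAIHTTPException,
    openai_exception_handler,
)

app = FastAPI()
app.add_exception_handler(OpenAIHTTPException, openai_exception_handler)
# Any route that raises OpenAIHTTPException is now converted into the
# JSONResponse built inside openai_exception_handler.
```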
129 changes: 69 additions & 60 deletions inference/api_openai_backend/router_app.py
@@ -119,12 +119,6 @@ async def _completions_wrapper(
)
]
usage = None
if subresult_dict["finish_reason"]:
usage = (
UsageInfo.from_response(ModelResponse.merge_stream(*all_results))
if all_results
else None
)
yield "data: " + CompletionResponse(
id=completion_id,
object="text_completion",
@@ -135,6 +129,19 @@
if had_error:
# Return early in case of an error
break
if not had_error:
usage = (
UsageInfo.from_response(ModelResponse.merge_stream(*all_results))
if all_results
else None
)
yield "data: " + CompletionResponse(
id=completion_id,
object="text_completion",
model=body.model,
choices=choices,
usage=usage,
).json() + "\n"
yield "data: [DONE]\n"


@@ -275,41 +282,40 @@ async def completions(
request_id,
body,
response,
self.query_client.stream(
body.model,
prompt,
request_id,
),
self.query_client.query(body.model, prompt, request_id, body.stream),
),
media_type="text/event-stream",
)
else:
async with async_timeout.timeout(TIMEOUT):
results = await self.query_client.query(body.model, prompt, request_id)
if results.error:
raise OpenAIHTTPException(
message=results.error.message,
status_code=results.error.code,
type=results.error.type,
)
results = results.dict()
results_reponse = self.query_client.query(
body.model, prompt, request_id, body.stream
)
async for results in results_reponse:
if results.error:
raise OpenAIHTTPException(
message=results.error.message,
status_code=results.error.code,
type=results.error.type,
)
results = results.dict()

choices = [
CompletionResponseChoice(
index=0,
text=results["generated_text"] or "",
finish_reason=results["finish_reason"],
)
]
usage = UsageInfo.from_response(results)
choices = [
CompletionResponseChoice(
index=0,
text=results["generated_text"] or "",
finish_reason=results["finish_reason"],
)
]
usage = UsageInfo.from_response(results)

return CompletionResponse(
id=request_id,
object="text_completion",
model=body.model,
choices=choices,
usage=usage,
)
return CompletionResponse(
id=request_id,
object="text_completion",
model=body.model,
choices=choices,
usage=usage,
)

@router_app.post("/v1/chat/completions")
async def chat(
@@ -332,39 +338,42 @@ async def chat(
request_id,
body,
response,
self.query_client.stream(body.model, prompt, request_id),
self.query_client.query(body.model, prompt, request_id, body.stream),
),
media_type="text/event-stream",
)
else:
async with async_timeout.timeout(TIMEOUT):
results = await self.query_client.query(body.model, prompt, request_id)
if results.error:
raise OpenAIHTTPException(
message=results.error.message,
status_code=results.error.code,
type=results.error.type,
)
results = results.dict()
results_reponse = self.query_client.query(
body.model, prompt, request_id, body.stream
)
async for results in results_reponse:
if results.error:
raise OpenAIHTTPException(
message=results.error.message,
status_code=results.error.code,
type=results.error.type,
)
results = results.dict()

choices: List[ChatCompletionResponseChoice] = [
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant", content=results["generated_text"] or ""
),
finish_reason=results["finish_reason"],
)
]
usage = UsageInfo.from_response(results)
choices: List[ChatCompletionResponseChoice] = [
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role="assistant", content=results["generated_text"] or ""
),
finish_reason=results["finish_reason"],
)
]
usage = UsageInfo.from_response(results)

return ChatCompletionResponse(
id=request_id,
object="chat.completion",
model=body.model,
choices=choices,
usage=usage,
)
return ChatCompletionResponse(
id=request_id,
object="chat.completion",
model=body.model,
choices=choices,
usage=usage,
)

@router_app.get("/v1/health_check")
async def health_check(self) -> bool:
14 changes: 11 additions & 3 deletions inference/deepspeed_predictor.py
@@ -16,6 +16,7 @@
from utils import get_torch_dtype
from inference.inference_config import (
InferenceConfig,
GenerateResult,
DEVICE_CPU,
DEVICE_XPU,
PRECISION_BF16,
@@ -232,19 +233,26 @@ def _init_worker_group(self, scaling_config: ScalingConfig):
)

def streaming_generate(self, prompt, streamer, **config):
input_ids = self.tokenize_inputs(prompt)
input_ids, _ = self.tokenize_inputs(prompt)
inputs_ref = ray.put(input_ids)
self.prediction_workers[0].streaming_generate.remote(inputs_ref, streamer, **config)
for worker in self.prediction_workers[1:]:
worker.streaming_generate.remote(inputs_ref, self._create_dummy_streamer(), **config)

def generate(self, prompt, **config):
input_ids = self.tokenize_inputs(prompt)
input_ids, input_length = self.tokenize_inputs(prompt)
inputs_ref = ray.put(input_ids)
gen_tokens = ray.get(
[worker.generate.remote(inputs_ref, **config) for worker in self.prediction_workers]
)[0]
return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
decode_result = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
if isinstance(prompt, list) and len(prompt) > 1:
return decode_result
return GenerateResult(
text=decode_result,
input_length=input_length,
generate_length=gen_tokens.size()[1] - input_length,
)

def get_streamer(self):
from transformers import TextStreamer
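With tokenize_inputs() now also reporting the prompt length, the non-streaming generate() can return a GenerateResult carrying both lengths, which is what the OpenAI layer needs to fill in token usage. A small hedged sketch of deriving usage counts from it; the field names follow the GenerateResult definition below, while the usage dictionary layout is only illustrative.

```python
# Sketch: deriving OpenAI-style usage numbers from a GenerateResult.
from inference.inference_config import GenerateResult

def usage_from_result(result: GenerateResult) -> dict:
    # input_length / generate_length may be None (e.g. for batched prompts),
    # so fall back to 0 before adding them up.
    prompt_tokens = result.input_length or 0
    completion_tokens = result.generate_length or 0
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }

# usage_from_result(GenerateResult(text="hi", input_length=12, generate_length=3))
# -> {'prompt_tokens': 12, 'completion_tokens': 3, 'total_tokens': 15}
```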
7 changes: 7 additions & 0 deletions inference/inference_config.py
@@ -58,6 +58,13 @@ def _check_load_in_low_bit(cls, v: str):
return v


# for non-streaming response
class GenerateResult(BaseModel):
text: Union[str, List[str]] = ""
input_length: Union[int, None] = None
generate_length: Union[int, None] = None


class ModelDescription(BaseModel):
model_id_or_path: Union[str, None] = None
bigdl: bool = False
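GenerateResult.text accepts either a single string or a batch of strings, matching the single-prompt and multi-prompt paths in the predictors above. A brief sketch of both shapes, with illustrative values.

```python
# Sketch: the two shapes GenerateResult is used with.
from inference.inference_config import GenerateResult

# Single prompt: text is a plain string and both lengths are known.
single = GenerateResult(text="Hello there!", input_length=5, generate_length=4)

# Batched prompts: text may be a list; the lengths can be left unset (None).
batch = GenerateResult(text=["answer one", "answer two"])

print(single.generate_length)  # 4
print(batch.input_length)      # None
```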
10 changes: 7 additions & 3 deletions inference/predictor.py
@@ -1,7 +1,7 @@
import re
import torch
from transformers import AutoTokenizer, StoppingCriteriaList
from inference.inference_config import InferenceConfig
from inference.inference_config import InferenceConfig, GenerateResult
from utils import StoppingCriteriaSub
from typing import List, AsyncGenerator, Union

@@ -28,10 +28,14 @@ def __init__(self, infer_conf: InferenceConfig) -> None:
for stop_word in stop_words
]
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
self.input_length = None

def tokenize_inputs(self, text):
input_tokens = self.tokenizer(text, return_tensors="pt", padding=True)
return input_tokens.input_ids.to(device=self.device)
input_ids = input_tokens.input_ids
self.input_length = input_ids.size()[1]
input_ids = input_ids.to(device=self.device)
return input_ids, self.input_length

def configure_tokenizer(self, model_name):
model = self.model
@@ -73,7 +77,7 @@ def configure_tokenizer(self, model_name):
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

def generate(self, prompts: Union[str, List[str]], **config) -> Union[str, List[str]]:
def generate(self, prompts: Union[str, List[str]], **config) -> GenerateResult:
pass

async def generate_async(
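Since tokenize_inputs() now returns (input_ids, input_length) and generate() is typed to return GenerateResult, a concrete subclass would look roughly like the sketch below. It is a generic transformers-style example under assumed attributes (notably self.model), not the project's actual predictor implementations, and the base-class import path is assumed from the file location.

```python
# Sketch of a Predictor subclass using the new tokenize_inputs() contract.
from inference.inference_config import GenerateResult
from inference.predictor import Predictor  # assumed import path for the base class above

class SketchPredictor(Predictor):
    """Illustrative subclass; assumes self.model was loaded elsewhere."""

    def generate(self, prompts, **config) -> GenerateResult:
        input_ids, input_length = self.tokenize_inputs(prompts)
        gen_tokens = self.model.generate(
            input_ids,
            stopping_criteria=self.stopping_criteria,
            **config,
        )
        text = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
        return GenerateResult(
            text=text if isinstance(prompts, list) else text[0],
            input_length=input_length,
            generate_length=gen_tokens.size()[1] - input_length,
        )
```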