diff --git a/scripts/openai_server_demo/README.md b/scripts/openai_server_demo/README.md
index e5dbfe9..5ef1171 100644
--- a/scripts/openai_server_demo/README.md
+++ b/scripts/openai_server_demo/README.md
@@ -8,7 +8,7 @@
 安装依赖
 
 ``` shell
-pip install fastapi uvicorn shortuuid
+pip install fastapi uvicorn shortuuid sse_starlette
 ```
 
 启动脚本
diff --git a/scripts/openai_server_demo/openai_api_protocol.py b/scripts/openai_server_demo/openai_api_protocol.py
index 36ed1b1..55b0239 100644
--- a/scripts/openai_server_demo/openai_api_protocol.py
+++ b/scripts/openai_server_demo/openai_api_protocol.py
@@ -1,10 +1,11 @@
-from typing import Optional, List, Dict, Any, Union
+from typing import Optional, List, Dict, Any, Union, Literal
 import time
 
 import shortuuid
 
 from pydantic import BaseModel, Field
 
+
 class ChatCompletionRequest(BaseModel):
     model: str = "chinese-llama-alpaca"
     messages: Union[str, List[Dict[str, str]]]
@@ -26,17 +27,30 @@ class ChatMessage(BaseModel):
     content: str
 
 
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["user", "assistant", "system"]] = None
+    content: Optional[str] = None
+
+
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
 
 
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]]
+
+
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
     object: str = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str = "chinese-llama-alpaca"
-    choices: List[ChatCompletionResponseChoice]
+    choices: List[
+        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
+    ]
 
 
 class EmbeddingsRequest(BaseModel):
@@ -76,6 +90,5 @@ class CompletionResponse(BaseModel):
     id: Optional[str] = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
     object: Optional[str] = "text_completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    model: Optional[str] = 'chinese-llama-alpaca'
+    model: Optional[str] = "chinese-llama-alpaca"
     choices: List[CompletionResponseChoice]
-
diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py
index 5d6072d..23de11d 100644
--- a/scripts/openai_server_demo/openai_api_server.py
+++ b/scripts/openai_server_demo/openai_api_server.py
@@ -2,15 +2,28 @@
 import os
 from fastapi import FastAPI
 import uvicorn
-
+from threading import Thread
+from sse_starlette.sse import EventSourceResponse
 parser = argparse.ArgumentParser()
-parser.add_argument('--base_model', default=None, type=str, required=True)
-parser.add_argument('--lora_model', default=None, type=str,help="If None, perform inference on the base model")
-parser.add_argument('--tokenizer_path',default=None,type=str)
-parser.add_argument('--gpus', default="0", type=str)
-parser.add_argument('--load_in_8bit',action='store_true', help='use 8 bit model')
-parser.add_argument('--only_cpu',action='store_true',help='only use CPU for inference')
-parser.add_argument('--alpha',type=str,default="1.0", help="The scaling factor of NTK method, can be a float or 'auto'. ")
") +parser.add_argument("--base_model", default=None, type=str, required=True) +parser.add_argument( + "--lora_model", + default=None, + type=str, + help="If None, perform inference on the base model", +) +parser.add_argument("--tokenizer_path", default=None, type=str) +parser.add_argument("--gpus", default="0", type=str) +parser.add_argument("--load_in_8bit", action="store_true", help="use 8 bit model") +parser.add_argument( + "--only_cpu", action="store_true", help="only use CPU for inference" +) +parser.add_argument( + "--alpha", + type=str, + default="1.0", + help="The scaling factor of NTK method, can be a float or 'auto'. ", +) args = parser.parse_args() load_in_8bit = args.load_in_8bit if args.only_cpu is True: @@ -19,10 +32,16 @@ import torch import torch.nn.functional as F -from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig +from transformers import ( + LlamaForCausalLM, + LlamaTokenizer, + GenerationConfig, + TextIteratorStreamer, +) from peft import PeftModel from patches import apply_attention_patch, apply_ntk_scaling_patch + apply_attention_patch(use_memory_efficient_attention=True) apply_ntk_scaling_patch(args.alpha) @@ -36,13 +55,15 @@ CompletionResponseChoice, EmbeddingsRequest, EmbeddingsResponse, + ChatCompletionResponseStreamChoice, + DeltaMessage, ) load_type = torch.float16 if torch.cuda.is_available(): device = torch.device(0) else: - device = torch.device('cpu') + device = torch.device("cpu") if args.tokenizer_path is None: args.tokenizer_path = args.lora_model if args.lora_model is None: @@ -54,28 +75,34 @@ load_in_8bit=load_in_8bit, torch_dtype=load_type, low_cpu_mem_usage=True, - device_map='auto' if not args.only_cpu else None, - ) + device_map="auto" if not args.only_cpu else None, +) model_vocab_size = base_model.get_input_embeddings().weight.size(0) tokenzier_vocab_size = len(tokenizer) print(f"Vocab of the base model: {model_vocab_size}") print(f"Vocab of the tokenizer: {tokenzier_vocab_size}") -if model_vocab_size!=tokenzier_vocab_size: +if model_vocab_size != tokenzier_vocab_size: assert tokenzier_vocab_size > model_vocab_size print("Resize model embeddings to fit tokenizer") base_model.resize_token_embeddings(tokenzier_vocab_size) if args.lora_model is not None: print("loading peft model") - model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',) + model = PeftModel.from_pretrained( + base_model, + args.lora_model, + torch_dtype=load_type, + device_map="auto", + ) else: model = base_model -if device==torch.device('cpu'): +if device == torch.device("cpu"): model.float() model.eval() + def generate_completion_prompt(instruction: str): """Generate prompt for completion""" return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. @@ -85,23 +112,25 @@ def generate_completion_prompt(instruction: str): ### Response: """ + def generate_chat_prompt(messages: list): """Generate prompt for chat completion""" - system_msg = '''Below is an instruction that describes a task. Write a response that appropriately completes the request.''' + system_msg = """Below is an instruction that describes a task. 
     for msg in messages:
-        if msg.role == 'system':
+        if msg.role == "system":
             system_msg = msg.content
     prompt = f"{system_msg}\n\n"
     for msg in messages:
-        if msg.role == 'system':
+        if msg.role == "system":
             continue
-        if msg.role == 'assistant':
+        if msg.role == "assistant":
             prompt += f"### Response: {msg.content}\n\n"
-        if msg.role == 'user':
+        if msg.role == "user":
             prompt += f"### Instruction:\n{msg.content}\n\n"
     prompt += "### Response: "
     return prompt
 
+
 def predict(
     input,
     max_new_tokens=128,
@@ -146,19 +175,86 @@ def predict(
     output = output.split("### Response:")[-1].strip()
     return output
 
+
+def stream_predict(
+    input,
+    max_new_tokens=128,
+    top_p=0.75,
+    temperature=0.1,
+    top_k=40,
+    num_beams=4,
+    repetition_penalty=1.0,
+    do_sample=True,
+    model_id="chinese-llama-alpaca",
+    **kwargs,
+):
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id,
+        choices=[choice_data],
+        object="chat.completion.chunk",
+    )
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
+    if isinstance(input, str):
+        prompt = generate_completion_prompt(input)
+    else:
+        prompt = generate_chat_prompt(input)
+    inputs = tokenizer(prompt, return_tensors="pt")
+    input_ids = inputs["input_ids"].to(device)
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        num_beams=num_beams,
+        do_sample=do_sample,
+        **kwargs,
+    )
+
+    streamer = TextIteratorStreamer(tokenizer)
+    generation_kwargs = dict(
+        streamer=streamer,
+        input_ids=input_ids,
+        generation_config=generation_config,
+        return_dict_in_generate=True,
+        output_scores=False,
+        max_new_tokens=max_new_tokens,
+        repetition_penalty=float(repetition_penalty),
+    )
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
+    for new_text in streamer:
+        if new_text.startswith("<s>"):
+            continue
+        if new_text.endswith("</s>"):
+            new_text = new_text.split("</s>")[0]
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
+        )
+        chunk = ChatCompletionResponse(
+            model=model_id, choices=[choice_data], object="chat.completion.chunk"
+        )
+        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0, delta=DeltaMessage(), finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id, choices=[choice_data], object="chat.completion.chunk"
+    )
+    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "[DONE]"
+
+
 def get_embedding(input):
     """Get embedding main function"""
     with torch.no_grad():
         if tokenizer.pad_token == None:
-            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-        encoding = tokenizer(
-            input, padding=True, return_tensors="pt"
-        )
+            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        encoding = tokenizer(input, padding=True, return_tensors="pt")
         input_ids = encoding["input_ids"].to(device)
         attention_mask = encoding["attention_mask"].to(device)
-        model_output = model(
-            input_ids, attention_mask, output_hidden_states=True
-        )
+        model_output = model(input_ids, attention_mask, output_hidden_states=True)
         data = model_output.hidden_states[-1]
         mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
         masked_embeddings = data * mask
@@ -169,16 +265,30 @@ def get_embedding(input):
     ret = normalized_embeddings.squeeze(0).tolist()
     return ret
 
+
 app = FastAPI()
 
+
@app.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): """Creates a completion for the chat message""" msgs = request.messages if isinstance(msgs, str): - msgs = [ChatMessage(role='user',content=msgs)] + msgs = [ChatMessage(role="user", content=msgs)] else: - msgs = [ChatMessage(role=x['role'],content=x['message']) for x in msgs] + msgs = [ChatMessage(role=x["role"], content=x["message"]) for x in msgs] + if request.stream: + generate = stream_predict( + input=msgs, + max_new_tokens=request.max_tokens, + top_p=request.top_p, + top_k=request.top_k, + temperature=request.temperature, + num_beams=request.num_beams, + repetition_penalty=request.repetition_penalty, + do_sample=request.do_sample, + ) + return EventSourceResponse(generate, media_type="text/event-stream") output = predict( input=msgs, max_new_tokens=request.max_tokens, @@ -189,9 +299,16 @@ async def create_chat_completion(request: ChatCompletionRequest): repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) - choices = [ChatCompletionResponseChoice(index = i, message = msg) for i, msg in enumerate(msgs)] - choices += [ChatCompletionResponseChoice(index = len(choices), message = ChatMessage(role='assistant',content=output))] - return ChatCompletionResponse(choices = choices) + choices = [ + ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs) + ] + choices += [ + ChatCompletionResponseChoice( + index=len(choices), message=ChatMessage(role="assistant", content=output) + ) + ] + return ChatCompletionResponse(choices=choices) + @app.post("/v1/completions") async def create_completion(request: CompletionRequest): @@ -206,23 +323,24 @@ async def create_completion(request: CompletionRequest): repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) - choices = [CompletionResponseChoice(index = 0, text = output)] - return CompletionResponse(choices = choices) + choices = [CompletionResponseChoice(index=0, text=output)] + return CompletionResponse(choices=choices) + @app.post("/v1/embeddings") async def create_embeddings(request: EmbeddingsRequest): """Creates text embedding""" embedding = get_embedding(request.input) - data = [{ - "object": "embedding", - "embedding": embedding, - "index": 0 - }] + data = [{"object": "embedding", "embedding": embedding, "index": 0}] return EmbeddingsResponse(data=data) if __name__ == "__main__": log_config = uvicorn.config.LOGGING_CONFIG - log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" - log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" - uvicorn.run(app, host='0.0.0.0', port=19327, workers=1, log_config=log_config) + log_config["formatters"]["access"][ + "fmt" + ] = "%(asctime)s - %(levelname)s - %(message)s" + log_config["formatters"]["default"][ + "fmt" + ] = "%(asctime)s - %(levelname)s - %(message)s" + uvicorn.run(app, host="0.0.0.0", port=19327, workers=1, log_config=log_config)