-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathentrypoint.py
90 lines (74 loc) · 2.73 KB
/
entrypoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import base64
import io
import logging
import os
import sys
import traceback
import uvicorn
from fastapi import FastAPI, HTTPException, Response, status
from PIL import Image
from sdxl_turbo import SdxlTurboRequest, setup_pipeline
# Runtime configuration is supplied through the container environment.
MODEL_NAME = os.getenv("MODEL_NAME")
CACHED_MODEL_PATH = os.getenv('SAVE_PATH')

# The root logger defaults to WARNING, which would silently drop INFO messages
# such as the startup "model loaded" line. basicConfig does nothing if the root
# logger already has handlers, so an externally configured setup is respected.
logging.basicConfig(level=logging.INFO)

if MODEL_NAME is None or CACHED_MODEL_PATH is None:
    # Name the environment variables that are actually read above:
    # CACHED_MODEL_PATH is only the Python-side name; the env var is SAVE_PATH.
    logging.error("Environment variables MODEL_NAME and SAVE_PATH must be set. See Dockerfile for values.")
    sys.exit(1)

app = FastAPI()
@app.on_event("startup")
def load_model():
    """Startup hook: build the text-to-image and image-to-image pipelines once.

    The pair returned by ``setup_pipeline`` is published as the module-level
    globals ``pipe_t2i`` and ``pipe_i2i``, which the generation endpoints call.
    """
    global pipe_t2i, pipe_i2i
    loaded = setup_pipeline(MODEL_NAME, CACHED_MODEL_PATH)
    pipe_t2i, pipe_i2i = loaded
    logging.info("Sdxl Turbo model loaded.")
# Heartbeat endpoint
@app.get("/")
async def heartbeat():
    """Liveness probe: always answers with ``{"status": "alive"}``."""
    return dict(status="alive")
# sdxl turbo text to image endpoint
@app.post("/sdxl-turbo-t2i")
async def generate_t2i(request: SdxlTurboRequest):
    """Render ``request.prompt`` with the text-to-image pipeline and reply with a PNG.

    Only the first image of the pipeline output is returned. Any failure is
    logged with its traceback and reported to the client as HTTP 500.
    """
    try:
        settings = {
            "prompt": request.prompt,
            "strength": request.strength,
            "guidance_scale": request.guidance_scale,
            "num_images_per_prompt": request.num_images_per_prompt,
            "num_inference_steps": request.num_inference_steps,
        }
        output = pipe_t2i(**settings)
        first_image = output.images[0]

        png_stream = io.BytesIO()
        first_image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
        return Response(
            content=png_bytes,
            media_type="image/png",
            status_code=status.HTTP_200_OK,
        )
    except Exception as err:
        logging.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(err)}")
# Raw (headerless) uploads are interpreted as packed 8-bit RGB pixels of this size.
_RAW_INIT_IMAGE_SIZE = (512, 512)


def _decode_init_image(image_b64):
    """Decode the base64 ``image`` field of an image-to-image request into an RGB image.

    Two payload forms are accepted:

    * raw packed RGB bytes of exactly ``_RAW_INIT_IMAGE_SIZE`` (the only form
      this endpoint originally understood, decoded exactly as before), and
    * any encoded image file PIL can identify (PNG, JPEG, ...), converted to RGB.

    Raises:
        ValueError: if the field is missing, cannot be base64-decoded, or the
            decoded bytes are neither the expected raw length nor a readable image.
    """
    if image_b64 is None:
        raise ValueError("an 'image' field is required for image-to-image")
    data = base64.b64decode(image_b64)  # binascii.Error is a ValueError subclass
    width, height = _RAW_INIT_IMAGE_SIZE
    if len(data) == width * height * 3:
        # Exact raw-buffer length: keep the original interpretation.
        return Image.frombytes("RGB", _RAW_INIT_IMAGE_SIZE, data, "raw")
    try:
        with Image.open(io.BytesIO(data)) as encoded:
            return encoded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise ValueError(
            f"image is neither {width}x{height} raw RGB bytes nor a readable encoded image"
        ) from e


# sdxl turbo image to image endpoint
@app.post("/sdxl-turbo-i2i")
async def generate_i2i(request: SdxlTurboRequest):
    """Transform the uploaded ``request.image`` with the image-to-image pipeline.

    Returns the first generated image as a PNG response.

    Raises:
        HTTPException: 400 if the input image cannot be decoded (a client error);
            500 if generation or PNG encoding fails (logged with its traceback).
    """
    # Decode before the broad handler below. An HTTPException raised inside it
    # would be caught by ``except Exception`` and re-wrapped as a 500.
    try:
        input_image = _decode_init_image(request.image)
    except (ValueError, TypeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid input image: {e}") from e

    try:
        image = pipe_i2i(
            image=input_image,
            prompt=request.prompt,
            strength=request.strength,
            guidance_scale=request.guidance_scale,
            num_images_per_prompt=request.num_images_per_prompt,
            num_inference_steps=request.num_inference_steps,
        ).images[0]
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            png_bytes = buffer.getvalue()
    except Exception as e:
        # Log via the logging module, like the text-to-image endpoint, not stdout.
        logging.exception("Image-to-image generation failed")
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}") from e

    return Response(
        content=png_bytes,
        status_code=status.HTTP_200_OK,
        media_type="image/png",
    )
if __name__ == "__main__":
    # An optional first CLI argument overrides the default listen port (8080).
    cli_args = sys.argv[1:]
    port = int(cli_args[0]) if cli_args else 8080
    uvicorn.run("entrypoint:app", host="0.0.0.0", port=port)