Replace zibai with uvicorn
leng-yue committed Jun 8, 2024
1 parent dc8c834 commit 46dae9b
Showing 2 changed files with 24 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "wandb>=0.15.11",
     "grpcio>=1.58.0",
     "kui>=1.6.0",
-    "zibai-server>=0.9.0",
+    "uvicorn>=0.30.0",
     "loguru>=0.6.0",
     "loralib>=0.1.2",
     "natsort>=8.4.0",
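Note: zibai-server is a synchronous WSGI server, while uvicorn is an ASGI server, so this one-line dependency swap is what drives every change in tools/api.py below: the kui.wsgi imports become kui.asgi, and the handlers become coroutines. As a minimal standalone sketch (not code from this repo), the interface uvicorn expects looks like this:

```python
# Minimal sketch of a bare ASGI callable, the protocol uvicorn serves;
# the real app in this commit is built with kui.asgi instead.
import uvicorn


async def app(scope, receive, send):
    # Each HTTP request arrives as a "scope" dict plus receive/send channels.
    assert scope["type"] == "http"
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/plain")],
        }
    )
    await send({"type": "http.response.body", "body": b"pong"})


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
```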
44 changes: 23 additions & 21 deletions tools/api.py
@@ -16,16 +16,17 @@
 import pyrootutils
 import soundfile as sf
 import torch
-from kui.wsgi import (
+from kui.asgi import (
     Body,
     FileResponse,
     HTTPException,
     HttpView,
     JSONResponse,
     Kui,
     OpenAPI,
     StreamResponse,
 )
-from kui.wsgi.routing import MultimethodRoutes
+from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
 from pydantic import BaseModel, Field
 from transformers import AutoTokenizer
@@ -57,7 +58,7 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
 
 
 # Define utils for web server
-def http_execption_handler(exc: HTTPException):
+async def http_execption_handler(exc: HTTPException):
     return JSONResponse(
         dict(
             statusCode=exc.status_code,
@@ -69,7 +70,7 @@ def http_execption_handler(exc: HTTPException):
     )
 
 
-def other_exception_handler(exc: "Exception"):
+async def other_exception_handler(exc: "Exception"):
     traceback.print_exc()
 
     status = HTTPStatus.INTERNAL_SERVER_ERROR
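Both handlers keep synchronous bodies; only the async def signature changes, because kui.asgi awaits its exception handlers on the event loop. A sketch of how handlers like these are typically registered, assuming Kui accepts a Starlette-style exception_handlers mapping (the registration itself is outside this diff):

```python
# Hypothetical registration, assuming Kui's constructor takes an
# exception_handlers mapping; this diff does not show where the
# app object is constructed.
app = Kui(
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    }
)
```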
@@ -334,8 +335,17 @@ def inference(req: InvokeRequest):
     yield fake_audios
 
 
+async def inference_async(req: InvokeRequest):
+    for chunk in inference(req):
+        yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+    yield buffer
+
+
 @routes.http.post("/v1/invoke")
-def api_invoke_model(
+async def api_invoke_model(
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
     """
@@ -354,22 +364,21 @@ def api_invoke_model(
             content="Streaming only supports WAV format",
         )
 
-    generator = inference(req)
     if req.streaming:
         return StreamResponse(
-            iterable=generator,
+            iterable=inference_async(req),
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
             content_type=get_content_type(req.format),
         )
     else:
-        fake_audios = next(generator)
+        fake_audios = next(inference(req))
         buffer = io.BytesIO()
         sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
 
         return StreamResponse(
-            iterable=[buffer.getvalue()],
+            iterable=buffer_to_async_generator(buffer.getvalue()),
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
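On the client side, the streaming branch can now be consumed incrementally. A hedged sketch: req.streaming and req.format are visible in the diff, but the "text" field is an assumption, since InvokeRequest's definition is not shown here:

```python
import requests

# Hypothetical client; "text" is an assumed field of InvokeRequest.
with requests.post(
    "http://127.0.0.1:8000/v1/invoke",
    json={"text": "Hello", "streaming": True, "format": "wav"},
    stream=True,
) as resp:
    resp.raise_for_status()
    with open("audio.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            f.write(chunk)
```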
@@ -378,7 +387,7 @@
 
 
 @routes.http.post("/v1/health")
-def api_health():
+async def api_health():
     """
     Health check
     """
@@ -409,6 +418,7 @@ def parse_args():
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+    parser.add_argument("--workers", type=int, default=1)
 
     return parser.parse_args()
@@ -433,7 +443,7 @@ def parse_args():
 if __name__ == "__main__":
     import threading
 
-    from zibai import create_bind_socket, serve
+    import uvicorn
 
     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16
@@ -480,13 +490,5 @@ def parse_args():
     )
 
     logger.info(f"Warming up done, starting server at http://{args.listen}")
-    sock = create_bind_socket(args.listen)
-    sock.listen()
-
-    # Start server
-    serve(
-        app=app,
-        bind_sockets=[sock],
-        max_workers=10,
-        graceful_exit=threading.Event(),
-    )
+    host, port = args.listen.split(":")
+    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
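Two caveats on the new startup code (note also that `import threading`, previously needed for zibai's graceful_exit event, survives the swap unused). First, uvicorn only honors workers > 1 when the application is passed as an import string rather than an object, so --workers above 1 with this call will make uvicorn warn and exit; a sketch of the import-string form follows, where "tools.api:app" is an assumption about the module layout. Second, args.listen.split(":") assumes an IPv4 host:port pair and would break on IPv6 literals.

```python
# Hypothetical multi-worker form; uvicorn requires an import string
# (not an app object) when workers > 1. "tools.api:app" is an assumed
# module path for this repo.
uvicorn.run(
    "tools.api:app",
    host=host,
    port=int(port),
    workers=args.workers,
    log_level="info",
)
```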
