Skip to content

Commit

Permalink
add HealthCheck endpoint, enable reflection
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrifiro committed Jun 13, 2024
1 parent 6206c1c commit b685d67
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 13 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ dependencies = [
"vllm>=0.4.3",
"prometheus_client==0.20.0",
"grpcio==1.62.1",
"grpcio-health-checking==1.62.1",
"grpcio-reflection==1.62.1",
"transformers",
"accelerate==0.28.0",
"hf-transfer==0.1.6"
Expand All @@ -41,7 +43,8 @@ Source = "https://github.com/dtrifiro/vllm_tgis_adapter"
[project.optional-dependencies]
tests = [
"pytest==8.2.0",
"pytest-cov==5.0.0"
"pytest-cov==5.0.0",
"pytest-mock==3.14.0"
]
dev = [
"vllm_tgis_adapter[tests]",
Expand Down Expand Up @@ -144,7 +147,8 @@ ignore = [
"FIX", # todo messages
"ERA001", # commented out code
# formatting
"COM812"
"COM812",
"RET504" # Unnecessary assignment to `args` before `return` statement
]
select = ["ALL"]

Expand Down
43 changes: 32 additions & 11 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import grpc
from grpc import StatusCode, aio
from grpc._cython.cygrpc import AbortError
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.async_llm_engine import _AsyncLLMEngine
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
Expand All @@ -31,6 +33,7 @@
)

from .pb import generation_pb2_grpc
from .pb.generation_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR
from .pb.generation_pb2 import (
BatchedGenerationResponse,
BatchedTokenizeResponse,
Expand Down Expand Up @@ -137,7 +140,16 @@ async def func_with_log(*args, **kwargs): # noqa: ANN002,ANN003,ANN202


class TextGenerationService(generation_pb2_grpc.GenerationServiceServicer):
def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
SERVICE_NAME = _GENERATION_DESCRIPTOR.services_by_name[
"GenerationService"
].full_name

def __init__(
self,
engine: AsyncLLMEngine,
args: argparse.Namespace,
health_servicer: health.HealthServicer,
):
self.engine: AsyncLLMEngine = engine

# These are set in post_init()
Expand All @@ -148,6 +160,8 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

self.health_servicer = health_servicer

@property
def tokenizer_group(self) -> BaseTokenizerGroup:
assert hasattr(self.engine.engine, "tokenizer")
Expand All @@ -173,6 +187,11 @@ async def post_init(self) -> None:
# 🌶️🌶️🌶️ sneaky sneak
self.engine.engine.stat_logger = tgis_stats_logger

self.health_servicer.set(
self.SERVICE_NAME,
health_pb2.HealthCheckResponse.SERVING,
)

@log_rpc_handler_errors
async def Generate(
self,
Expand Down Expand Up @@ -727,19 +746,21 @@ async def start_grpc_server(
logger.info(memory_summary(engine.engine.device_config.device))

server = aio.server()
service = TextGenerationService(engine, args)
await service.post_init()

generation_pb2_grpc.add_GenerationServiceServicer_to_server(service, server)
health_servicer = health.HealthServicer()
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

generation = TextGenerationService(engine, args, health_servicer)
await generation.post_init()
generation_pb2_grpc.add_GenerationServiceServicer_to_server(generation, server)

# TODO add reflection
service_names = (
health.SERVICE_NAME,
generation.SERVICE_NAME,
reflection.SERVICE_NAME,
)

# SERVICE_NAMES = (
# generation_pb2.DESCRIPTOR.services_by_name["GenerationService"]
# .full_name,
# reflection.SERVICE_NAME,
# )
# reflection.enable_server_reflection(SERVICE_NAMES, server)
reflection.enable_server_reflection(service_names, server)

host = "0.0.0.0" if args.host is None else args.host # noqa: S104
listen_on = f"{host}:{args.grpc_port}"
Expand Down
112 changes: 112 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import asyncio
import sys
import threading

import pytest
from grpc_health.v1.health_pb2 import HealthCheckRequest
from grpc_health.v1.health_pb2_grpc import Health
from vllm import AsyncLLMEngine
from vllm.config import DeviceConfig
from vllm.engine.async_llm_engine import _AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser

from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService, start_grpc_server
from vllm_tgis_adapter.tgis_utils.args import (
EnvVarArgumentParser,
add_tgis_args,
postprocess_tgis_args,
)

from .utils import get_random_port, wait_until


@pytest.fixture()
def engine(mocker, monkeypatch):
    """Build a mock that stands in for a vLLM ``AsyncLLMEngine``.

    The returned mock exposes an ``engine`` attribute (a mock of
    ``_AsyncLLMEngine``) carrying a mocked ``device_config`` whose ``device``
    is ``"cuda"`` and a placeholder ``stat_logger``.
    """
    # Patched to return a fixed string instead of a real memory report.
    mocker.patch("torch.cuda.memory_summary", return_value="mocked")

    device_config = mocker.Mock(spec=DeviceConfig)
    device_config.device = "cuda"

    inner_engine = mocker.Mock(spec=_AsyncLLMEngine)
    inner_engine.device_config = device_config
    inner_engine.stat_logger = "mocked"

    async_engine = mocker.Mock(spec=AsyncLLMEngine)
    async_engine.engine = inner_engine
    return async_engine


@pytest.fixture()
def args(monkeypatch, grpc_server_thread_port):
    """Parse adapter arguments as if only ``--grpc-port`` had been passed.

    Returns the namespace produced by ``postprocess_tgis_args`` for a command
    line containing just the port chosen by ``grpc_server_thread_port``.
    """
    # Swap out sys.argv so that pytest's own command-line arguments are not
    # parsed as vllm/vllm_tgis_adapter arguments.
    fake_argv = ["__main__.py", f"--grpc-port={grpc_server_thread_port}"]
    monkeypatch.setattr(sys, "argv", fake_argv)

    parser = add_tgis_args(EnvVarArgumentParser(parser=make_arg_parser()))
    return postprocess_tgis_args(parser.parse_args())


@pytest.fixture()
def grpc_server_thread_port():
    """Pick the port the grpc server under test will listen on."""
    port = get_random_port()
    return port


@pytest.fixture()
def grpc_server_url(grpc_server_thread_port):
    """Return the ``host:port`` address of the grpc server under test.

    The host is always ``localhost``; the port is the one provided by the
    ``grpc_server_thread_port`` fixture.
    """
    return f"localhost:{grpc_server_thread_port}"


@pytest.fixture()
def grpc_server(engine, args, grpc_server_url):
    """Spin up the grpc server in a background thread and yield it.

    The server runs on a dedicated event loop driven by a worker thread. The
    fixture blocks until the health endpoint reports the generation service
    as ``SERVING``, yields the server object to the test, and on teardown
    asks the server to stop and joins the worker thread.
    """

    def health_check():
        # Raises (grpc error or AssertionError) until the service is SERVING,
        # which makes it usable as a predicate for wait_until().
        print("Waiting for server to be up...")  # noqa: T201
        request = HealthCheckRequest(service=TextGenerationService.SERVICE_NAME)
        resp = Health.Check(
            request=request,
            target=grpc_server_url,
            timeout=1,
            insecure=True,
        )
        assert resp.status == resp.SERVING
        print("Server is up")  # noqa: T201

    global server  # noqa: PLW0602

    loop = asyncio.new_event_loop()

    async def run_server():
        global server  # noqa: PLW0603

        server = await start_grpc_server(engine, args)
        # Keep the loop (and therefore the worker thread) alive for as long
        # as the server is running.
        while server._server.is_running():  # noqa: SLF001
            await asyncio.sleep(1)

    def target():
        loop.run_until_complete(run_server())

    t = threading.Thread(target=target)
    t.start()

    async def stop():
        global server  # noqa: PLW0602

        await server.stop(grace=None)
        await server.wait_for_termination()

    try:
        wait_until(health_check)
        yield server
    finally:
        # The loop is owned by the worker thread, so the shutdown coroutine
        # has to be handed over with the thread-safe API; loop.create_task()
        # must only be called from the thread running the loop.
        # If the worker thread already exited, there is no loop left to run
        # the coroutine, so skip scheduling it.
        if t.is_alive():
            asyncio.run_coroutine_threadsafe(stop(), loop)
        t.join()
3 changes: 3 additions & 0 deletions tests/test_grpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def test_startup(grpc_server):
    """Check that the server yielded by the grpc_server fixture is running."""
    running = grpc_server._server.is_running()  # noqa: SLF001
    assert running
34 changes: 34 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import socket
import time
from contextlib import closing
from typing import Callable, TypeVar

_T = TypeVar("_T")


def get_random_port():
    """Return a TCP port number that the OS reported as free.

    A throwaway IPv4 socket is bound to port 0 so the OS picks the port; the
    socket is closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # AF_INET addresses are (host, port) pairs.
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port


def wait_until(
    pred: Callable[..., _T],
    timeout: float = 30,
    pause: float = 0.5,
) -> _T:
    """Call *pred* repeatedly until it returns without raising.

    Args:
        pred: Zero-argument callable; any ``Exception`` it raises counts as
            "not ready yet".
        timeout: Maximum number of seconds to keep retrying.
        pause: Seconds to sleep between attempts.

    Returns:
        The value returned by the first successful call to *pred*.

    Raises:
        TimeoutError: If *pred* never succeeded before the deadline. The last
            exception raised by *pred* is attached as the cause.
    """
    deadline = time.perf_counter() + timeout
    exc = None

    # Always attempt at least once, even when timeout <= 0.
    while True:
        try:
            return pred()
        except Exception as e:  # noqa: BLE001
            exc = e

        remaining = deadline - time.perf_counter()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(pause, remaining))

    raise TimeoutError("timed out waiting") from exc

0 comments on commit b685d67

Please sign in to comment.