Skip to content

Commit

Permalink
write to /dev/termination-log on main loop exception.
Browse files Browse the repository at this point in the history
This includes picking up server config errors, but does NOT
make any attempt at recovering exception stacks raised
inside vLLM's RPC server, as they occur in a separate
process; those are only reported as a generic RuntimeError.
  • Loading branch information
NickLucche authored and dtrifiro committed Sep 17, 2024
1 parent 047f767 commit 355a088
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 9 deletions.
34 changes: 32 additions & 2 deletions src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import asyncio
import contextlib
import os
import traceback
from concurrent.futures import FIRST_COMPLETED
from typing import TYPE_CHECKING

Expand All @@ -17,7 +19,7 @@
from .http import run_http_server
from .logging import init_logger
from .tgis_utils.args import EnvVarArgumentParser, add_tgis_args, postprocess_tgis_args
from .utils import check_for_failed_tasks
from .utils import check_for_failed_tasks, write_termination_log

if TYPE_CHECKING:
import argparse
Expand All @@ -43,10 +45,21 @@ async def start_servers(args: argparse.Namespace) -> None:
)
tasks.append(grpc_server_task)

runtime_error = None
with contextlib.suppress(asyncio.CancelledError):
# Both server tasks will exit normally on shutdown, so we await
# FIRST_COMPLETED to catch either one shutting down.
await asyncio.wait(tasks, return_when=FIRST_COMPLETED)
if engine and engine.errored and not engine.is_running:
# both servers shut down when an engine error
# is detected, with task done and exception handled
# here we just notify of that error and let servers be
runtime_error = RuntimeError(
"AsyncEngineClient error detected,this may be caused by an \
unexpected error in serving a request. \
Please check the logs for more details."
)

# Once either server shuts down, cancel the other
for task in tasks:
task.cancel()
Expand All @@ -55,6 +68,22 @@ async def start_servers(args: argparse.Namespace) -> None:
await asyncio.wait(tasks)

check_for_failed_tasks(tasks)
if runtime_error:
raise runtime_error


def run_and_catch_termination_cause(
    loop: asyncio.AbstractEventLoop, task: asyncio.Task
) -> None:
    """Drive *task* to completion on *loop*, recording why it failed.

    If the task raises an ``Exception``, the formatted traceback is handed
    to ``write_termination_log`` and the exception is re-raised unchanged.

    The destination path is read from the ``TERMINATION_LOG_DIR``
    environment variable and defaults to ``/dev/termination-log``.
    """
    try:
        loop.run_until_complete(task)
    except Exception:
        # The first exception to reach the main loop is reported as the
        # cause of termination; it still propagates to the caller.
        log_path = os.getenv("TERMINATION_LOG_DIR", "/dev/termination-log")
        write_termination_log(traceback.format_exc(), log_path)
        raise


if __name__ == "__main__":
Expand All @@ -75,4 +104,5 @@ async def start_servers(args: argparse.Namespace) -> None:

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
loop = asyncio.new_event_loop()
loop.run_until_complete(start_servers(args))
task = loop.create_task(start_servers(args))
run_and_catch_termination_cause(loop, task)
18 changes: 17 additions & 1 deletion src/vllm_tgis_adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,20 @@ def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
name = task.get_name()
coro_name = task.get_coro().__name__

raise RuntimeError(f"task={name} ({coro_name})") from exc
raise RuntimeError(f"task={name} ({coro_name}) exception={exc!s}") from exc


def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:
    """Write *msg* to the termination logfile on a best-effort basis.

    The file is truncated and rewritten with ``msg`` followed by a newline.
    This function never raises: if the file cannot be written, the failure
    is logged and the caller continues, so the original termination cause
    is not masked by a secondary error.

    Args:
        msg: Text describing why the process is terminating.
        file: Path of the termination logfile.
    """
    # From https://github.com/IBM/text-generation-inference/blob/9388f02d222c0dab695bea1fb595cacdf08d5467/server/text_generation_server/utils/termination.py#L4
    try:
        # Pin the encoding so the result does not depend on the process
        # locale; backslashreplace keeps unencodable characters (e.g. lone
        # surrogates in a traceback) from aborting the write.
        with open(
            file, "w", encoding="utf-8", errors="backslashreplace"
        ) as termination_file:
            termination_file.write(f"{msg}\n")
    except Exception:
        # Deliberately broad: a failure here must not replace the error
        # being reported. Users can fall back to the stdout logs, where
        # the failure to write is recorded.
        from .logging import init_logger

        logger = init_logger("vllm-tgis-adapter")
        logger.exception("Unable to write termination logs to %s", file)
16 changes: 11 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

from vllm_tgis_adapter.__main__ import start_servers
from vllm_tgis_adapter.__main__ import run_and_catch_termination_cause, start_servers
from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService
from vllm_tgis_adapter.healthcheck import health_check
from vllm_tgis_adapter.tgis_utils.args import (
Expand Down Expand Up @@ -68,6 +68,11 @@ def disable_frontend_multiprocessing(request):
return request.param


@pytest.fixture
def server_args(request: pytest.FixtureRequest):
    """Return extra server CLI flags supplied via parametrization.

    Falls back to an empty list when the request carries no ``param``.
    """
    return getattr(request, "param", [])


@pytest.fixture
def args( # noqa: PLR0913
request: pytest.FixtureRequest,
Expand All @@ -76,11 +81,13 @@ def args( # noqa: PLR0913
http_server_port: ArgFixture[int],
lora_available: ArgFixture[bool],
disable_frontend_multiprocessing,
server_args: ArgFixture[list[str]],
) -> argparse.Namespace:
"""Return parsed CLI arguments for the adapter/vLLM."""
# avoid parsing pytest arguments as vllm/vllm_tgis_adapter arguments

extra_args: list[str] = []
# Extra server init flags
extra_args: list[str] = [*server_args]
if lora_available:
name = request.getfixturevalue("lora_adapter_name")
path = request.getfixturevalue("lora_adapter_path")
Expand Down Expand Up @@ -179,9 +186,8 @@ def dummy_signal_handler(*args, **kwargs):

def target():
nonlocal task

task = loop.create_task(start_servers(args))
loop.run_until_complete(task)
run_and_catch_termination_cause(loop, task)

t = threading.Thread(target=target)
t.start()
Expand All @@ -195,7 +201,7 @@ def target():

t.join()

# rorkaround: Instantiating the TGISStatLogger multiple times creates
# workaround: Instantiating the TGISStatLogger multiple times creates
# multiple Gauges etc which can only be instantiated once.
# By unregistering the Collectors from the REGISTRY we can
# work around this problem.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_completions(http_server_url, _servers):
response = requests.post(
f"{http_server_url}/v1/completions",
json={
"prompt": "The answer tho life the universe and everything is ",
"prompt": "The answer to life the universe and everything is ",
"model": model_id,
},
)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_termination_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from .utils import TaskFailedError


@pytest.fixture
def termination_log_fpath(tmp_path, monkeypatch):
    """Redirect the termination log to a per-test temporary path.

    Yields the path; the file itself is not created here. Any file left
    at that path is removed once the test is done.
    """
    log_path = tmp_path / "termination_log.txt"
    # Set the destination before the server starts.
    monkeypatch.setenv("TERMINATION_LOG_DIR", str(log_path))
    yield log_path
    log_path.unlink(missing_ok=True)


@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param(["--enable-lora"], id="enable-lora"),
        pytest.param(["--max-model-len=10241024"], id="huge-model-len"),
        pytest.param(["--model=google-bert/bert-base-uncased"], id="unsupported-model"),
    ],
    indirect=True,
)
def test_startup_fails(request, args, termination_log_fpath, lora_available):
    """Check that common set-up errors crash the server on startup.

    Each parametrized flag set is expected to make the servers fail, and
    a termination log file must exist afterwards.
    """
    if lora_available and args.enable_lora:
        pytest.skip("This test requires a non-lora supported device to run")

    # The server fixture is requested explicitly, rather than declared as
    # a parameter, so the exception it raises can be caught here.
    with pytest.raises(TaskFailedError):
        request.getfixturevalue("_servers")

    # The failure must have produced a termination log file.
    assert termination_log_fpath.exists()

0 comments on commit 355a088

Please sign in to comment.