Skip to content

Commit

Permalink
minor: add mla fp8 test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyncs committed Sep 23, 2024
1 parent e4780cf commit 5b0ad9e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ jobs:
run: |
cd test/srt
python3 test_mla.py
python3 test_mla_fp8.py
- name: Evaluate Data Parallelism Accuracy (TP=2)
timeout-minutes: 10
Expand Down
1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Meta-Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
Expand Down
50 changes: 50 additions & 0 deletions test/srt/test_mla_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class TestMLA(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--tp",
"2",
"--trust-remote-code",
"--kv-cache-dtype",
"fp8_e5m2",
],
)

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)

def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)

metrics = run_eval(args)
assert metrics["score"] >= 0.8


if __name__ == "__main__":
unittest.main()

0 comments on commit 5b0ad9e

Please sign in to comment.