diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 42130e4926..c0a247a28f 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -239,6 +239,7 @@ jobs: run: | cd test/srt python3 test_mla.py + python3 test_mla_fp8.py - name: Evaluate Data Parallelism Accuracy (TP=2) timeout-minutes: 10 diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 1181ca24e8..fbe45bb2ff 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -26,6 +26,7 @@ DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" +DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Meta-Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" diff --git a/test/srt/test_mla_fp8.py b/test/srt/test_mla_fp8.py new file mode 100644 index 0000000000..37275d696b --- /dev/null +++ b/test/srt/test_mla_fp8.py @@ -0,0 +1,50 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestMLA(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tp", + "2", + "--trust-remote-code", + "--kv-cache-dtype", + "fp8_e5m2", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +if __name__ == "__main__": + unittest.main()