From 6e95d2ae3e58decee7cd0ebb849625d6d4a91d53 Mon Sep 17 00:00:00 2001
From: spicysama
Date: Thu, 12 Sep 2024 21:19:01 +0800
Subject: [PATCH] Fix breakdown infer (#534)

* fully support ormsgpack

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* dependency

* torch==2.4.1 windows compilable

* Update docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove autorerank

* api usage

* back slash

* fix docs

* Fix infer warmup params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* max_new_tokens=1024

* Fix break down infer

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 tools/llama/generate.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tools/llama/generate.py b/tools/llama/generate.py
index ad9c5499..d717ce7c 100644
--- a/tools/llama/generate.py
+++ b/tools/llama/generate.py
@@ -605,7 +605,7 @@ def worker():
     multiple=True,
 )
 @click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--max-new-tokens", type=int, default=1024)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
@@ -650,7 +650,10 @@ def main(
     model, decode_one_token = load_model(
         checkpoint_path, device, precision, compile=compile
     )
-
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+        )
     if torch.cuda.is_available():
         torch.cuda.synchronize()
 