From 821d1c21f9c2920dbff4c18dde093ff05d8ca8a5 Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Tue, 2 Apr 2024 11:18:47 -0400
Subject: [PATCH 1/2] added speculator as hf model

---
 server/text_generation_server/models/paged_causal_lm.py | 8 +++-----
 server/text_generation_server/utils/paged.py            | 4 ++--
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/paged_causal_lm.py b/server/text_generation_server/models/paged_causal_lm.py
index 22e145ce..9b1fc9c3 100644
--- a/server/text_generation_server/models/paged_causal_lm.py
+++ b/server/text_generation_server/models/paged_causal_lm.py
@@ -333,12 +333,10 @@ def __init__(
         if SPECULATOR_PATH is not None:
-            from fms.modules.speculator import Speculator
+            from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
             print(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
-            self.speculator = Speculator(model_config.hidden_size, model_config.vocab_size, n_predict=3).to(self.device)
-            self.speculator.load_state_dict(
-                torch.load(SPECULATOR_PATH, map_location=self.device)["model_state"]
-            )
+            self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(SPECULATOR_PATH)
+            self.speculator.to(device=self.device, dtype=dtype)
         else:
             self.speculator = None
diff --git a/server/text_generation_server/utils/paged.py b/server/text_generation_server/utils/paged.py
index e8f380d7..789a982a 100644
--- a/server/text_generation_server/utils/paged.py
+++ b/server/text_generation_server/utils/paged.py
@@ -144,8 +144,8 @@ def prepare_inputs_with_speculation(
     n_adds = speculator.n_predict + 1
 
     #hard-code some values
-    top_k = 5
-    threshes=[5, 3, 2]
+    top_k = speculator.config.n_candidates
+    threshes= speculator.config.top_k_tokens_per_head
     flatting=True
 
     # create candidate sequences

From 0b25d59d3bf9e5ddc464bfdc2bed6766744e5bff Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Tue, 2 Apr 2024 11:39:37 -0400
Subject: [PATCH 2/2] re-added sample

---
 scripts/speculative_generation_sample.py | 74 ++++++++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 scripts/speculative_generation_sample.py

diff --git a/scripts/speculative_generation_sample.py b/scripts/speculative_generation_sample.py
new file mode 100644
index 00000000..7b652274
--- /dev/null
+++ b/scripts/speculative_generation_sample.py
@@ -0,0 +1,74 @@
+from text_generation_server.models import get_model
+from text_generation_server.pb import generate_pb2
+from typing import List
+import time
+import torch
+
+template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
+
+text = template.format(
+    "Provide a list of instructions for preparing chicken soup."
+)
+
+def __generate_prefill_request(id: int, batch_size: int, num_new_tokens: List[int]):
+
+    out = generate_pb2.PrefillRequest(
+        batch=generate_pb2.Batch(
+            id=id,
+            requests=[
+                generate_pb2.Request(
+                    id=i, inputs=text, input_length=49, max_output_length=num_new_tokens[i],
+                    parameters=generate_pb2.NextTokenChooserParameters(
+                        temperature=0.0,
+                    )
+                ) for i in range(batch_size)
+            ]
+        )
+    )
+    return out
+
+model = get_model(
+    model_name="/net/storage149/mnt/md0/jmrosenk/llama_weights/hf/7B-F",
+    revision=None,
+    deployment_framework="tgis_native",
+    dtype_str="float16",
+    quantize=None,
+    max_sequence_length=2048
+)
+
+
+num_new_tokens = [100]
+
+request1 = __generate_prefill_request(0, 1, num_new_tokens)
+
+batch1, errors = model.batch_type.from_pb(
+    request1.batch,
+    tokenizer=model.tokenizer,
+    dtype=model.dtype,
+    device=model.device,
+    embeddings_lookup=model.word_embeddings,
+    prefix_cache=model.prefix_cache,
+    use_position_ids=model.use_position_ids,
+)
+
+# token info (token ids) - have tokenizer in model,
+result = ""
+token_info_list, _, _, _ = model.generate_token(batch1, first=True, for_concat=False)
+
+cumulative_t = 0
+total_decode_tokens = 0
+token_ids_out = [token_info_list[0].token_id]
+while batch1.input_lengths[0] < batch1.total_lengths[0]:
+    t0 = time.time_ns()
+    token_info_list, _, _, _ = model.generate_token(batch1)
+    torch.cuda.synchronize(device=model.device)
+    t_tok = time.time_ns()-t0
+    cumulative_t += t_tok
+    total_decode_tokens += len(token_info_list)
+    print(f"number of tokens per step: {len(token_info_list)}")
+    for token_info in token_info_list:
+        token_ids_out.append(token_info.token_id)
+    print("t_tok: %.2f ms" % (t_tok / len(token_info_list) / 1000.0 / 1000.0))
+
+print(model.tokenizer.decode(token_ids_out))
+print(f"avg per token: {cumulative_t / total_decode_tokens / 1000000}")
\ No newline at end of file