diff --git a/server/text_generation_server/models/paged_causal_lm.py b/server/text_generation_server/models/paged_causal_lm.py
index c25cd3d0..3808b3fb 100644
--- a/server/text_generation_server/models/paged_causal_lm.py
+++ b/server/text_generation_server/models/paged_causal_lm.py
@@ -321,12 +321,10 @@ def __init__(
         from fms_extras.utils.cache.paged import PagedKVCacheManager
 
         if SPECULATOR_PATH is not None:
-            from fms_extras.models.speculator import MLPSpeculator
+            from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
             print(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
-            self.speculator = MLPSpeculator(model_config.hidden_size, vocab_size=model_config.vocab_size, n_predict=3).to(device=self.device, dtype=dtype)
-            self.speculator.load_state_dict(
-                torch.load(SPECULATOR_PATH, map_location=self.device)["model_state"]
-            )
+            self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(SPECULATOR_PATH)
+            self.speculator.to(device=self.device, dtype=dtype)
         else:
             self.speculator = None
 
@@ -340,7 +338,6 @@ def __init__(
             device=self.device,
         )
 
-    @property
     def batch_type(self) -> Type[PagedCausalLMBatch]:
         return self._batch_type
 
diff --git a/server/text_generation_server/utils/paged.py b/server/text_generation_server/utils/paged.py
index d6e72331..6c1eaf4f 100644
--- a/server/text_generation_server/utils/paged.py
+++ b/server/text_generation_server/utils/paged.py
@@ -144,8 +144,8 @@ def prepare_inputs_with_speculation(
     n_adds = speculator.n_predict + 1
 
-    #hard-code some values
-    top_k = 5
-    threshes=[5, 3, 2]
+    # read speculation parameters from the speculator's config
+    top_k = speculator.config.n_candidates
+    threshes = speculator.config.top_k_tokens_per_head
     flatting=True
 
     # create candidate sequences
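For context, a minimal sketch of the new loading path, using only names that appear in the diff above; the checkpoint path, device, and dtype are illustrative placeholders, not values from this change:

import torch
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel

# Load a speculator saved in Hugging Face from_pretrained format; this
# replaces the old MLPSpeculator(...) + load_state_dict(torch.load(...)) path.
speculator = MLPSpeculatorPreTrainedModel.from_pretrained("/path/to/speculator")  # placeholder path
speculator.to(device="cuda", dtype=torch.float16)  # placeholder device/dtype

# The speculation parameters now travel with the checkpoint instead of
# being hard-coded in utils/paged.py:
top_k = speculator.config.n_candidates              # was: top_k = 5
threshes = speculator.config.top_k_tokens_per_head  # was: threshes = [5, 3, 2]
n_adds = speculator.n_predict + 1                   # n_predict speculated tokens plus the base token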