Commit

revert gpu->cpu async copy due to slowdown
sunggg committed Jan 9, 2024
1 parent 9a7256a commit 18a789c
Showing 3 changed files with 51 additions and 41 deletions.
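
For context, the change being reverted copied sampling results from the GPU to the CPU with non_blocking transfers on dedicated CUDA streams; this commit restores plain blocking copies. A minimal, self-contained sketch of the two patterns, assuming PyTorch with a CUDA device (tensor names and shapes are illustrative, not the project's code):

import numpy as np
import torch

def async_copy_on_side_stream(logits: torch.Tensor) -> np.ndarray:
    # Pattern being removed: run argmax and the device-to-host copy on a
    # dedicated stream, then make the current stream wait for it.
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        tokens = torch.argmax(logits, -1)
        tokens_cpu = tokens.to(torch.device("cpu"), non_blocking=True)
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()  # ensure the host buffer is populated before .numpy()
    return tokens_cpu.numpy()

def blocking_copy(logits: torch.Tensor) -> np.ndarray:
    # Pattern being restored: a plain blocking .cpu() copy on the current stream.
    return torch.argmax(logits, -1).cpu().numpy()

if torch.cuda.is_available():
    logits = torch.randn(4, 32000, device="cuda")
    assert (async_copy_on_side_stream(logits) == blocking_copy(logits)).all()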
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -17,4 +17,4 @@
PROMPT_SEQEUNCE_INDEX,
get_prompt_sequence_id,
)
from .sampling_params import SamplingParams, SamplingType
from .sampling_params import SamplingParams, SamplingType, _SAMPLING_EPS as SAMPLING_EPS
18 changes: 9 additions & 9 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -177,12 +177,15 @@ def __init__(
self.sliding_window = config.sliding_window
self.num_shards = config.num_shards

if config.model_type == "llama":
self.torch_dtype = torch.float32
elif config.model_type == "mistral" or config.model_type == "mixtral":
self.torch_dtype = torch.float32
else:
assert 0, f"{config.model_type} is NOT supported yet"

self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
self._greedy_sampling_stream: torch.cuda.Stream = torch.cuda.Stream()
self._random_sampling_stream: torch.cuda.Stream = torch.cuda.Stream()
self.torch_dev = torch.from_dlpack(
tvm.nd.array(np.array([], dtype="float32"), self.dev)
).device
self.torch_dev = "cuda"

if self.sliding_window:
self.block_sliding_window = self.sliding_window // CacheManager.block_size
@@ -344,9 +347,8 @@ def generate(
# Prepare sampling tensors in another stream to overlap
# CPU<->GPU data transfer with GPU computation in forward pass.
with torch.cuda.stream(self._copy_stream):
# TODO(@sunggg): fix the datatype
sampling_tensors = get_tensors_for_sampling(
sampling_params, torch.float32, self.torch_dev, self.vocab_size
sampling_params, self.torch_dtype, self.torch_dev, self.vocab_size
)

# Last synchronization point for model execution
Expand Down Expand Up @@ -383,8 +385,6 @@ def generate(
logits,
sampling_tensors,
self.vocab_size,
self._greedy_sampling_stream,
self._random_sampling_stream,
)

# assert next_tokens is not None
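The generate() hunk above keeps the self._copy_stream path: sampling tensors are built on pinned host memory and copied to the GPU on a separate stream so the transfer can overlap with the forward pass. A minimal sketch of that overlap pattern, assuming a CUDA device (the helper name and tensor contents are illustrative):

import torch

def stage_temperatures(temperatures, dev):
    # Stage on pinned host memory so the host-to-device copy can be asynchronous.
    temp_cpu = torch.tensor(temperatures, dtype=torch.float32,
                            device="cpu", pin_memory=True)
    return temp_cpu.to(dev, non_blocking=True)

if torch.cuda.is_available():
    dev = torch.device("cuda")
    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        temp_gpu = stage_temperatures([0.7, 1.0, 1.0, 0.9], dev)
    logits = torch.randn(4, 32000, device=dev)            # stands in for the forward pass
    torch.cuda.current_stream().wait_stream(copy_stream)  # sync before consuming temp_gpu
    logits = logits / temp_gpu.unsqueeze(1)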
72 changes: 41 additions & 31 deletions serve/mlc_serve/model/sampler.py
@@ -5,7 +5,7 @@
import tvm
from ..engine import (
SamplingType,
SamplingParams,
SAMPLING_EPS,
)

LOG = structlog.stdlib.get_logger(__name__)
@@ -80,6 +80,8 @@ def get_tensors_for_sampling(sampling_params, dtype, dev, vocab_size):
do_top_k = False
has_random = False
has_greedy = False
apply_penalty = False
apply_bias = False
frequency_penalties = []
presence_penalties = []
rep_penalties = []
@@ -88,24 +90,41 @@ def get_tensors_for_sampling(sampling_params, dtype, dev, vocab_size):
past_output_tokens = []
for param in sampling_params:
# Prepare temperature
temperatures.append(param.temperature)
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperatures.append(
param.temperature if param.temperature >= SAMPLING_EPS else 1.0
)

if param.sampling_type == SamplingType.RANDOM:
has_random |= True
top_ps.append(param.top_p)
top_ks.append(param.top_k if param.top_k != -1 else vocab_size)
do_top_p |= top_ps[-1] < 1.0
do_top_p |= top_ps[-1] < 1.0 - SAMPLING_EPS
do_top_k |= top_ks[-1] != vocab_size
else:
has_greedy |= True

past_output_tokens.append(param.output_tokens)

apply_penalty |= (
abs(param.presence_penalty) >= SAMPLING_EPS
or abs(param.frequency_penalty >= SAMPLING_EPS)
or abs(param.repetition_penalty - 1.0) >= SAMPLING_EPS
)
frequency_penalties.append(param.frequency_penalty)
presence_penalties.append(param.presence_penalty)
assert param.repetition_penalty != 0
rep_penalties.append(param.repetition_penalty)

past_output_tokens.append(param.output_tokens)
logit_bias_indices.append(param.logit_bias_index)
logit_bias_values.append(param.logit_bias_value)
if param.logit_bias_index:
assert param.logit_bias_value
apply_bias |= True
logit_bias_indices.append(param.logit_bias_index)
logit_bias_values.append(param.logit_bias_value)
else:
logit_bias_indices.append([])
logit_bias_values.append([])

temp_t = torch.tensor(
temperatures,
@@ -151,8 +170,6 @@ def get_tensors_for_sampling(sampling_params, dtype, dev, vocab_size):
device="cpu",
pin_memory=True,
)
apply_penalty = True
apply_bias = True
logit_bias_indices_t = torch.tensor(
logit_bias_indices,
dtype=torch.long,
@@ -190,8 +207,6 @@ def sample(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sampling_tensors,
vocab_size,
greedy_sampling_stream,
random_sampling_stream,
check_safety=False,
) -> Optional[np.ndarray]:
def _is_safe_to_sample(prob_like):
@@ -253,30 +268,25 @@ def _is_safe_to_sample(prob_like):
)
res_greedy, res_random = np.array([]), np.array([])

with torch.cuda.stream(greedy_sampling_stream):
if has_greedy:
logits_greedy = logits[mask_greedy_t]
res_greedy = torch.argmax(logits_greedy, -1)
res_greedy = res_greedy.to(torch.device("cpu"), non_blocking=True).numpy()

with torch.cuda.stream(random_sampling_stream):
if has_random:
logits_random = logits[mask_random_t]
# Further adjust logits with the factors related to random sampling
logits_random.div_(temp_t[mask_random_t].unsqueeze(dim=1))
if apply_top_p_top_k:
logits = _apply_top_p_top_k(logits_random, top_ps_t, top_ks_t)
if has_greedy:
logits_greedy = logits[mask_greedy_t]
res_greedy = torch.argmax(logits_greedy, -1)
res_greedy = res_greedy.cpu().numpy()

probs = torch.softmax(logits_random, dim=-1)
if has_random:
logits_random = logits[mask_random_t]
# Further adjust logits with the factors related to random sampling
logits_random.div_(temp_t[mask_random_t].unsqueeze(dim=1))
if apply_top_p_top_k:
logits = _apply_top_p_top_k(logits_random, top_ps_t, top_ks_t)

if check_safety and not _is_safe_to_sample(probs):
return None
probs = torch.softmax(logits_random, dim=-1)

res_random = _multinomial(probs, 1)[:, 0]
res_random = res_random.to(torch.device("cpu"), non_blocking=True).numpy()
if check_safety and not _is_safe_to_sample(probs):
return None

torch.cuda.current_stream().wait_stream(greedy_sampling_stream)
torch.cuda.current_stream().wait_stream(random_sampling_stream)
res_random = _multinomial(probs, 1)[:, 0]
res_random = res_random.cpu().numpy()

# Prepare output
sequence_ids = np.array(sequence_ids)
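
With the dedicated greedy/random sampling streams and non_blocking copies dropped from sample(), results now come back through plain blocking .cpu() calls. A simplified standalone sketch of that flow (top-p/top-k and penalties omitted, and torch.multinomial used in place of the custom _multinomial helper; names are illustrative):

import torch

def sample_sketch(logits, temperatures, random_mask):
    # Greedy requests: plain argmax. Random requests: temperature-scale,
    # softmax, then draw one token. Both host copies are blocking.
    res_greedy = res_random = None
    if (~random_mask).any():
        res_greedy = torch.argmax(logits[~random_mask], -1).cpu().numpy()
    if random_mask.any():
        scaled = logits[random_mask] / temperatures[random_mask].unsqueeze(1)
        probs = torch.softmax(scaled, dim=-1)
        res_random = torch.multinomial(probs, 1)[:, 0].cpu().numpy()
    return res_greedy, res_random

if torch.cuda.is_available():
    logits = torch.randn(4, 32000, device="cuda")
    temps = torch.tensor([1.0, 0.7, 1.0, 0.9], device="cuda")
    mask = torch.tensor([False, True, False, True], device="cuda")
    greedy_tokens, random_tokens = sample_sketch(logits, temps, mask)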
