Llama rotary dims from 4 to 2
younesbelkada committed Nov 22, 2023
1 parent: 2bdbf2d · commit: fa254cf
Showing 1 changed file with 2 additions and 2 deletions.
src/petals/models/llama/block.py (4 changes: 2 additions & 2 deletions)
@@ -85,8 +85,8 @@ def forward(
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
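The slice change tracks the rotary-embedding output going from a 4-dimensional to a 2-dimensional tensor in newer transformers releases, as the commit title says. Below is a minimal sketch of why the same positions are now selected along the first dimension; it assumes cos/sin shapes of [1, 1, seq_len, head_dim] (old) and [seq_len, head_dim] (new), and the tensor names and sizes are illustrative, not Petals code.

import torch

# Illustrative sizes (assumptions, not values from block.py).
kv_seq_len, q_len, head_dim = 12, 3, 8

# Older layout: cos/sin shaped [1, 1, seq_len, head_dim],
# so the trailing q_len positions were taken along dim 2.
cos_4d = torch.randn(1, 1, kv_seq_len, head_dim)
tail_old = cos_4d[:, :, kv_seq_len - q_len :]   # [1, 1, q_len, head_dim]

# Newer layout: cos/sin shaped [seq_len, head_dim],
# so the same positions are taken along dim 0 -- the change in this commit.
cos_2d = cos_4d[0, 0]                           # [kv_seq_len, head_dim]
tail_new = cos_2d[kv_seq_len - q_len :]         # [q_len, head_dim]

assert torch.equal(tail_old[0, 0], tail_new)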
