diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index 841c7302a..b3deb511b 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -85,8 +85,8 @@ def forward(
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)