Skip to content

Commit

Permalink
Revert "Fix retries during inference"
Browse files — browse the repository at this point in the history
This reverts commit 3547b57.
• Branch information
borzunov committed Sep 28, 2023
1 parent 567e34b commit 160211a
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions src/petals/client/inference_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self.stepped = False
self.closed = False

self.position = 0
self._position = 0
self.history = None # Used in case of server failures to regenerate attention caches on new servers
self.next_session = None

Expand Down Expand Up @@ -102,11 +102,12 @@ def step(
n_input_tokens = inputs.shape[1]
if self.history is None:
self.history = inputs
elif self.history.shape[1] == self.position:
elif self.history.shape[1] == self._position:
self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
assert (
self.history.shape[1] == self.position + n_input_tokens
), f"Broken input cache: {self.span=} {self.history.shape=} {self.position=} {n_input_tokens=}"
assert self.history.shape[1] == self._position + n_input_tokens, (
f"Broken input cache: span={self.span} shape={self.history.shape} "
f"position={self._position} n_input_tokens={n_input_tokens}"
)

if not self.stepped:
inputs = self.history # Pass full inputs including prefix
Expand Down Expand Up @@ -173,7 +174,7 @@ def step(
outputs[0].shape == inputs.shape
), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"

self.position += n_input_tokens
self._position += n_input_tokens

return outputs[0]

Expand Down Expand Up @@ -363,10 +364,6 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
# If there is a failed span, this code replaces it, otherwise it just adds new ones
if server_idx < n_prev_spans:
updated_sessions[0].history = self._server_sessions[server_idx].history
updated_sessions[0].position = self._position
assert (
updated_sessions[0].history.shape[1] == self._position
), f"Broken input cache: {updated_sessions[0].history.shape=} {self._position=}"
self._server_sessions[server_idx : server_idx + 1] = updated_sessions

# Update links to the next server session for direct server-to-server communication via rpc_push()
Expand Down

0 comments on commit 160211a

Please sign in to comment.