From 160211add59b7be803df3b1ddde7ea98d8984613 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Thu, 28 Sep 2023 18:39:10 +0000 Subject: [PATCH] Revert "Fix retries during inference" This reverts commit 3547b57f975ff3d89b27b2fdf0c64f7d395ad876. --- src/petals/client/inference_session.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 03dc272a1..c89f6b1ae 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -52,7 +52,7 @@ def __init__( self.stepped = False self.closed = False - self.position = 0 + self._position = 0 self.history = None # Used in case of server failures to regenerate attention caches on new servers self.next_session = None @@ -102,11 +102,12 @@ def step( n_input_tokens = inputs.shape[1] if self.history is None: self.history = inputs - elif self.history.shape[1] == self.position: + elif self.history.shape[1] == self._position: self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1) - assert ( - self.history.shape[1] == self.position + n_input_tokens - ), f"Broken input cache: {self.span=} {self.history.shape=} {self.position=} {n_input_tokens=}" + assert self.history.shape[1] == self._position + n_input_tokens, ( + f"Broken input cache: span={self.span} shape={self.history.shape} " + f"position={self._position} n_input_tokens={n_input_tokens}" + ) if not self.stepped: inputs = self.history # Pass full inputs including prefix @@ -173,7 +174,7 @@ def step( outputs[0].shape == inputs.shape ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}" - self.position += n_input_tokens + self._position += n_input_tokens return outputs[0] @@ -363,10 +364,6 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> # If there is a failed span, this code replaces it, otherwise it just adds new ones if server_idx < n_prev_spans: updated_sessions[0].history = self._server_sessions[server_idx].history - updated_sessions[0].position = self._position - assert ( - updated_sessions[0].history.shape[1] == self._position - ), f"Broken input cache: {updated_sessions[0].history.shape=} {self._position=}" self._server_sessions[server_idx : server_idx + 1] = updated_sessions # Update links to the next server session for direct server-to-server communication via rpc_push()