
Petals doesn't deal with server failure properly #587

Open
oldcpple opened this issue Jun 26, 2024 · 4 comments

oldcpple commented Jun 26, 2024

Hi there, we'd like to report our findings from testing Petals' fault tolerance.

We note that the current implementation of the step method of the class _ServerInferenceSession in inference_session.py contains the following code:

if self.history is None:
    self.history = inputs
elif self.history.shape[1] == self._position:
    self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)

assert self.history.shape[1] == self._position + n_input_tokens, (
    f"Broken input cache: span={self.span} shape={self.history.shape} "
    f"position={self._position} n_input_tokens={n_input_tokens}"
)

where the attributes self.history and self._position are initialized to None and 0, respectively, when an object of the class _ServerInferenceSession is created. The problem is that when a server fails, Petals replaces it with another server serving the same blocks. However, the replacement server session is freshly initialized when it joins the inference session, so its _position attribute is 0.

In the method _update_sequence of the class InferenceSession, the new server session's history is assigned the history of the failed server session:

updated_sessions[0].history = self._server_sessions[server_idx].history

During token-by-token inference, n_input_tokens is always 1, so the assertion

assert self.history.shape[1] == self._position + n_input_tokens

almost always fails with a "Broken input cache" exception: the copied history already contains several tokens, while the fresh session's _position is still 0.
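
To make the failure mode concrete, here is a small self-contained sketch of the mismatch (toy tensors and plain variables stand in for the session attributes; this is not the real Petals code):

    import torch

    hidden_size = 8192

    # State of the failed server's session: it had cached 8 tokens of history
    # and advanced its position to 8 before going down.
    old_history = torch.zeros(1, 8, hidden_size)

    # The replacement session is freshly constructed (_position == 0),
    # but _update_sequence copies the failed session's history onto it.
    history = old_history
    position = 0

    # The next decoding step sends a single new token.
    n_input_tokens = 1
    inputs = torch.zeros(1, n_input_tokens, hidden_size)

    # Cache-update logic quoted above: neither branch applies here, because
    # history is not None and history.shape[1] (8) != position (0).
    if history is None:
        history = inputs
    elif history.shape[1] == position:
        history = torch.cat([history, inputs[:, -n_input_tokens:]], dim=1)

    try:
        # Compares 8 against 0 + 1, so the check fails.
        assert history.shape[1] == position + n_input_tokens
    except AssertionError:
        print("Broken input cache")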

One possible solution is as follows.
Delete the assert statement so that no exception is thrown during the recovery process, then change the last few lines of the step method to:

    self._position += n_input_tokens
    s1 = outputs[0].shape[1]
    if self.recover:
        # For a recovering session, return only the newest token's hidden states
        # so callers still receive a tensor of the expected shape [1, 1, hidden_size].
        return outputs[0][0:1, s1 - 1:s1, :]
    return outputs[0]

Here, self.recover is a newly defined attribute of the class _ServerInferenceSession that indicates whether this server session is recovering from a failed one; it is initialized to False and set to True in the method _update_sequence. This change addresses the following problem: if we simply delete the assert statement, the outputs[0] returned by the recovered session is a tensor of shape [1, (number of its history inputs), 8192] instead of the expected [1, 1, 8192].
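
To illustrate the intent, here is a hedged, self-contained sketch of how the proposed recover flag and the slice interact (toy tensors; the variable names outside the lines quoted above are assumptions, not the actual Petals implementation):

    import torch

    hidden_size = 8192
    n_input_tokens = 1     # always 1 during token-by-token decoding

    # Proposed flag: False for ordinary sessions, set to True in _update_sequence
    # for a session that replaces a failed server.
    recover = True

    # As described above, for a recovering session outputs[0] comes back with
    # shape [1, <number of its history inputs>, hidden_size] rather than [1, 1, hidden_size].
    history_inputs = 9  # arbitrary toy value for this example
    outputs = [torch.zeros(1, history_inputs, hidden_size)]

    s1 = outputs[0].shape[1]
    if recover:
        # Keep only the newest token's hidden states, restoring the expected
        # [1, 1, hidden_size] shape that callers rely on.
        result = outputs[0][0:1, s1 - 1:s1, :]
    else:
        result = outputs[0]

    print(result.shape)  # torch.Size([1, 1, 8192])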

After testing tens of examples, we believe this change handles server failures properly: the final outputs when a server fails are the same as when no server fails.

Please check whether these problems exist on your side. Many thanks.

justheuristic (Collaborator) commented Jul 11, 2024

@oldcpple Thank you! (and sorry for the delayed response)

Your reasoning looks sound. If you still have the bandwidth, could you kindly provide a minimal example where this assert fails, so we can reproduce it in CI tests? That would ensure that, once fixed, this error does not resurface.
I tried to check this with debug prints on our mini-swarm from CI, shutting down servers during inference, but the assert evaluated true all of the ~10 times I tested it.

If you don't, please say so, and I will keep trying.

Also, if you prefer to implement the fix in your own PR, you are most welcome to do so. If that's inconvenient, we will add you as "co-authored by" once we are certain that this fix does not have negative side effects.

justheuristic (Collaborator) commented Jul 11, 2024

I discussed this with @borzunov (core author): he also believes that you are correct and that the problem was introduced during one of our refactors. He hopes to make a pass this weekend (please note that this is not a binding promise; we are unfortunately still overwhelmed with research routine).

oldcpple (Author) commented Jul 12, 2024

@justheuristic Thanks for your reply!
We were running a StableBeluga2 model on a private swarm of 10 × 2080Ti GPUs and made sure that each block of the model was held by at least two servers. During an inference, we disconnected one of the servers on the server chain selected by the client; it took some time for the client to find another server to replace the disconnected one, and that was when the assert failed.
We tracked down the cause and fixed it with a few lines of code, as described above. Now the system successfully recovers the interrupted inference, and the final result is the same as when no server disconnects. However, we haven't verified whether the change we made in inference_session.py affects other parts.
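
For reference, the client side of our reproduction boils down to roughly the following sketch (the model name matches our setup; the initial_peers multiaddress is a placeholder for our private swarm's bootstrap peer, and the exact arguments are illustrative). While the generate call is running, we kill one of the servers on the chain chosen by the client:

    from transformers import AutoTokenizer
    from petals import AutoDistributedModelForCausalLM

    MODEL_NAME = "petals-team/StableBeluga2"
    INITIAL_PEERS = ["/ip4/10.0.0.1/tcp/31337/p2p/XXXX"]  # placeholder bootstrap peer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

    inputs = tokenizer("A cat sat on a mat and", return_tensors="pt")["input_ids"]

    # Generate enough tokens to leave time to disconnect one of the servers on the
    # selected chain; once the client switches to a replacement server, the
    # "Broken input cache" assertion fires in _ServerInferenceSession.step.
    outputs = model.generate(inputs, max_new_tokens=200)
    print(tokenizer.decode(outputs[0]))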

Since you mentioned that your tests ended with correct results, we suspect it might be a model-specific problem (but we haven't verified this). We hit this problem with StableBeluga2, and we found another issue with this model: the batch size seems to be fixed. For example, we first sent a request with 2 prompts (batch size = 2), and inference ran fine. Then, for the following requests, we changed the batch size to 10 or any value other than 2, and similar "Broken input cache" exceptions occurred. After testing this a few times, we are sure that the batch size the model can process must match the batch size of the first inference request. However, we believe this is a problem specific to StableBeluga2, since everything was fine in our test on another model, bloom-7b.

justheuristic (Collaborator) commented

For the record: the issue certainly exists, and I can confirm that it breaks fault tolerance on my side. We have not yet fixed it, as the issue keeps bouncing between me and @borzunov promising to "certainly fix this next weekend night". We hope to fix it and post an update, but everything is taking unreasonably long at the moment. We are sorry.
