Error while trying to save a sharded model checkpoint on multi-host TPU #7919
Comments
We have our own distributed checkpoint mechanism: https://github.com/pytorch/xla/blob/master/docs/spmd_distributed_checkpoint.md @jonb377
Ah, I see you already use that. @jonb377, can you take a look?
Are you using a shared mount at …? You can use any fsspec-compatible checkpoint path (e.g. …).
@jonb377 We tried mounting "/tmp/example" directly to a bucket using gcsfuse. It seems to work, since all the files are now saved in the bucket, but one of the distcp files (saved by the second host) is empty. We think this is a bug. Here is the ls output from both hosts (as you can see, one of the files is empty):
With the files in this state, we still get the same stack trace when trying to load the checkpoint (on one host the metadata file is missing, and on the other some weights are missing):
This is expected if the model parameters are not sharded across hosts: only the replica 0 shard is written to the checkpoint. Currently, we do not load-balance the writes for replicated tensor shards, but that's an area of optimization we can pursue.
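To illustrate why one host can end up with an empty distcp file, here is a minimal sketch of replica-0 deduplication in plain Python. The `Shard` record and the `plan_writes` helper are hypothetical and only model the behavior described above; they are not the torch_xla implementation. For a fully replicated tensor, every device holds a piece with the same offsets, and only the copy with `replica_id == 0` is written, so all of its bytes land on whichever host owns replica 0.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Hypothetical record for one device-local piece of a tensor."""

    tensor: str       # parameter name
    offsets: tuple    # start index of this piece in the global tensor
    replica_id: int   # 0 for the first copy of a replicated piece
    host: int         # host that owns the device holding this piece
    nbytes: int       # size of the piece in bytes


def plan_writes(shards):
    """Return bytes written per host when only replica-0 pieces are saved."""
    bytes_per_host = defaultdict(int)
    for shard in shards:
        # Other replicas of the same piece are skipped (deduplicated).
        if shard.replica_id == 0:
            bytes_per_host[shard.host] += shard.nbytes
    return dict(bytes_per_host)


# Fully replicated 4 MiB weight on 2 hosts x 4 devices: every device holds the
# whole tensor at offset (0, 0), and replica 0 lives on host 0.
replicated = [
    Shard("linear.weight", (0, 0), replica_id=r, host=r // 4, nbytes=4 << 20)
    for r in range(8)
]
# Same weight split along dim 0 into 8 pieces, one unique piece per device.
sharded = [
    Shard("linear.weight", (i * 128, 0), replica_id=0, host=i // 4, nbytes=512 << 10)
    for i in range(8)
]

print(plan_writes(replicated))  # {0: 4194304}  (host 1 writes nothing)
print(plan_writes(sharded))     # {0: 2097152, 1: 2097152}
```

When the parameters are sharded across hosts, every piece is unique, so each host writes its own share of the bytes.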
This is surprising; your implementation looks fine to me. I'll see if I can reproduce it on my end.
I found two issues:
You can verify the issue by reversing the sharding to (…). I will send a patch to address this. It seems upstream has moved to a more sophisticated deduplication approach that addresses this and also load-balances the writes. Thanks @dudulightricks for reporting the issue!
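The load-balancing idea mentioned above can be sketched as a greedy assignment: for each unique piece (same tensor and offsets), choose the least-loaded host among those holding a copy. This is a hypothetical illustration of the general technique, not the upstream `torch.distributed.checkpoint` code; the `balance_writes` helper, the `(tensor_name, offsets)` key, and the per-piece host lists are assumptions made for the example.

```python
from collections import defaultdict


def balance_writes(pieces):
    """Greedily assign each unique piece to the least-loaded host holding a copy.

    `pieces` maps (tensor_name, offsets) -> (nbytes, hosts_with_a_copy).
    Returns (assignment, bytes_per_host).
    """
    load = defaultdict(int)
    assignment = {}
    # Place the largest pieces first so the greedy choice balances better.
    for key, (nbytes, hosts) in sorted(
        pieces.items(), key=lambda item: item[1][0], reverse=True
    ):
        # Ties are broken by the lower host index to keep the plan deterministic.
        host = min(hosts, key=lambda h: (load[h], h))
        assignment[key] = host
        load[host] += nbytes
    return assignment, dict(load)


# Four fully replicated parameters, each with a copy on both hosts.
pieces = {
    ("layer1.weight", (0, 0)): (400, [0, 1]),
    ("layer1.bias", (0,)): (100, [0, 1]),
    ("layer2.weight", (0, 0)): (300, [0, 1]),
    ("layer2.bias", (0,)): (200, [0, 1]),
}
assignment, load = balance_writes(pieces)
print(load)  # {0: 500, 1: 500}
```

Sorting by size first is a common greedy heuristic; it does not guarantee an optimal split, but for replicated parameters it keeps a single host from receiving every write.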
🐛 Bug
To Reproduce
Here is a short example that reproduces the error on a vp-16 TPU pod:
Steps to reproduce the behavior:
Expected behavior
The model weights should be printed and should match the weights before saving.
Stack trace
The code fails with the following error:
Environment
Additional context
@JackCaoG This code is based on the basic example in the documentation, but it doesn't work. When we examined the saved data, it seems that only part of it is saved, and only on one host, while the other host has an empty binary file.
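One way to confirm that a checkpoint is incomplete, independent of the storage backend, is to check whether the chunks recorded for each tensor cover its full global shape. The sketch below does this in plain Python for chunks described as `(offsets, sizes)` tuples. The `covers_full_tensor` helper and its input format are hypothetical, loosely modeled on the per-chunk offsets and sizes that distributed checkpoint metadata records, and the check assumes chunks do not overlap.

```python
from math import prod


def covers_full_tensor(global_shape, chunks):
    """Return True if non-overlapping chunks exactly tile `global_shape`.

    Each chunk is an (offsets, sizes) pair of tuples with one entry per dim.
    Assumes chunks do not overlap, so once every chunk is checked to lie inside
    the tensor bounds, comparing element counts is enough.
    """
    total = 0
    for offsets, sizes in chunks:
        if len(offsets) != len(global_shape) or len(sizes) != len(global_shape):
            return False  # rank mismatch
        for off, size, dim in zip(offsets, sizes, global_shape):
            if off < 0 or size < 0 or off + size > dim:
                return False  # chunk falls outside the tensor
        total += prod(sizes)
    return total == prod(global_shape)


shape = (8, 4)
complete = [((0, 0), (4, 4)), ((4, 0), (4, 4))]
missing_half = [((0, 0), (4, 4))]  # e.g. the second host wrote nothing
print(covers_full_tensor(shape, complete))      # True
print(covers_full_tensor(shape, missing_half))  # False
```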