Error while trying to save a sharded model checkpoint on multi-host TPU #7919

Open
dudulightricks opened this issue Aug 28, 2024 · 6 comments

@dudulightricks

dudulightricks commented Aug 28, 2024

🐛 Bug

To Reproduce

Here is a short example that reproduces the error, running on a vp-16 TPU pod:

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
import torch.distributed as dist
import torch_xla.distributed.xla_backend
import time

from torch import nn
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

SAVING_PATH="/tmp/example4"

# Enable SPMD mode and initialize the process group over the XLA backend.
xr.use_spmd()
dist.init_process_group('gloo', init_method='xla://')

# Build a 2D device mesh spanning all global devices.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# Shard the weight's second dimension along the 'y' mesh axis.
model = nn.Linear(in_features=2, out_features=2, device=xm.xla_device())
model_sharded = xs.mark_sharding(model.weight, mesh, (None, 'y'))

print(f"model weight {model.weight} model bias {model.bias}")

state_dict = {
    "my_model": model.state_dict(),
}

# Save the sharded state_dict with the SPMD-aware save planner.
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(SAVING_PATH),
    planner=xc.SPMDSavePlanner(),
)

time.sleep(20)

state_dict = {
    "my_model": model.state_dict(),
}

# Load the checkpoint back with the SPMD-aware load planner.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(SAVING_PATH),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["my_model"])

print(f"model weight after load: {model.weight} bias: {model.bias}")

Steps to reproduce the behavior:

  1. Run this code on multi-host TPU (vp-16 in our case)
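To launch the script on all workers of the pod, something like the following gcloud invocation can be used (the TPU name, zone, and script name below are placeholders):

    gcloud compute tpus tpu-vm ssh my-tpu --zone=us-central2-b --worker=all --command="python3 repro.py"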

Expected behavior

The model weights should be printed, and should be the same as before the save.

Stack trace

The code fails with the following error:

[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
[rank0]:     _load_state_dict(
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
[rank0]:     central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank0]:     raise result
[rank0]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1])
[rank0]: Traceback (most recent call last): (RANK 0)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank0]:     local_plan = planner.create_local_plan()
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch_xla/experimental/distributed_checkpoint/planners.py", line 231, in create_local_plan
[rank0]:     xla_read_items = _create_xla_read_items(self.sharded_state_dict,
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch_xla/experimental/distributed_checkpoint/planners.py", line 380, in _create_xla_read_items
[rank0]:     md = metadata.state_dict_metadata[fqn]
[rank0]: KeyError: 'my_model.weight'
[rank0]: Traceback (most recent call last): (RANK 1)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 205, in local_step
[rank0]:     metadata = storage_reader.read_metadata()
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 677, in read_metadata
[rank0]:     with self.fs.create_stream(path, "rb") as metadata_file:
[rank0]:   File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
[rank0]:     return next(self.gen)
[rank0]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 388, in create_stream
[rank0]:     with cast(Path, path).open(mode) as stream:
[rank0]:   File "/usr/lib/python3.10/pathlib.py", line 1119, in open
[rank0]:     return self._accessor.open(self, mode, buffering, encoding, errors,
[rank0]: FileNotFoundError: [Errno 2] No such file or directory: '/tmp/example4/.metadata'

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU vp-16
  • torch_xla version: 2.4

Additional context

@JackCaoG This code follows the basic example in the documentation, but it doesn't work. When we examined the saved data, it seems that only part of it was saved, and only on one host; the other host has an empty binary file.

@JackCaoG
Collaborator

We have our own distributed checkpoint mechanism: https://github.com/pytorch/xla/blob/master/docs/spmd_distributed_checkpoint.md @jonb377

@JackCaoG
Collaborator

Ah, I see you already use that. @jonb377 can you take a look?

@jonb377
Collaborator

jonb377 commented Aug 28, 2024

Are you using a shared mount at /tmp/example4? In torch.distributed.checkpoint, all hosts need to have access to the same directory, since only the master process will write the .metadata file but each host needs the file to operate on the checkpoint.

You can use any fsspec-compatible checkpoint path (e.g. gs:// for a GCS bucket) instead of a local path for checkpointing.
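
For illustration, here is a minimal sketch of what that could look like, assuming the private FsspecWriter/FsspecReader classes in torch.distributed.checkpoint._fsspec_filesystem (being private, their module path and constructor arguments may differ across torch versions) and a hypothetical bucket path:

import torch.distributed.checkpoint as dist_cp
# Private module; the import path and signatures may change between releases.
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
import torch_xla.experimental.distributed_checkpoint as xc

CKPT_PATH = "gs://my-bucket/example4"  # hypothetical bucket path

# Save: every host writes its shards and rank 0 writes .metadata, all into the same bucket.
dist_cp.save(
    state_dict=state_dict,
    storage_writer=FsspecWriter(CKPT_PATH),
    planner=xc.SPMDSavePlanner(),
)

# Load: all hosts read the shared .metadata from the bucket.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=FsspecReader(CKPT_PATH),
    planner=xc.SPMDLoadPlanner(),
)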

@dudulightricks
Author

dudulightricks commented Aug 29, 2024

@jonb377 We tried mounting "/tmp/example" directly to a bucket using gcsfuse. It seems to work, since all the files are now saved in the bucket, but somehow one of the distcp files (the one saved by the 2nd host) is empty. We think this is a bug.

This is the ls output of the files from both hosts (as you can see, one of them is empty):

user@t1v-n-47240b2c-w-0:/opt/repo$ ls -la /tmp/example
total 3
-rw-r--r-- 1 user user  884 Aug 29 08:41 .metadata
-rw-r--r-- 1 user user 1180 Aug 29 08:41 __0_0.distcp
-rw-r--r-- 1 user user    0 Aug 29 08:41 __1_0.distcp

With the files in this state, we still get the same stack trace when trying to load the data (on one host the metadata file is missing, and on the other there are missing weights):

[rank1]: Traceback (most recent call last):
[rank1]:   File "/opt/repo/scripts/sharding_example6.py", line 45, in <module>
[rank1]:     dist_cp.load(
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank1]:     result = func(*args, **kwargs)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 434, in inner_func
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
[rank1]:     _load_state_dict(
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
[rank1]:     central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank1]:     raise result
[rank1]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1])
[rank1]: Traceback (most recent call last): (RANK 0)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank1]:     local_data = map_fun()
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank1]:     result = func(*args, **kwargs)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank1]:     local_plan = planner.create_local_plan()
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch_xla/experimental/distributed_checkpoint/planners.py", line 231, in create_local_plan
[rank1]:     xla_read_items = _create_xla_read_items(self.sharded_state_dict,
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch_xla/experimental/distributed_checkpoint/planners.py", line 380, in _create_xla_read_items
[rank1]:     md = metadata.state_dict_metadata[fqn]
[rank1]: KeyError: 'my_model.weight'
[rank1]: Traceback (most recent call last): (RANK 1)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank1]:     local_data = map_fun()
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank1]:     result = func(*args, **kwargs)
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 205, in local_step
[rank1]:     metadata = storage_reader.read_metadata()
[rank1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 678, in read_metadata
[rank1]:     metadata = pickle.load(metadata_file)
[rank1]: FileNotFoundError: [Errno 2] No such file or directory

@jonb377
Collaborator

jonb377 commented Aug 29, 2024

one of the distcp files (the one saved by the 2nd host) is empty

This is expected if the model parameters are not sharded across hosts: only the replica 0 shard is written to the checkpoint. Currently we do not load balance the writes for replicated tensor shards, but that's an area of optimization we can pursue.

CheckpointException ranks:dict_keys([0, 1])

This is surprising, your implementation looks fine to me. I'll see if I can repro on my end.

@jonb377
Collaborator

jonb377 commented Aug 29, 2024

I found two issues:

  • gcsfuse does not support file streaming well, and it treats the .metadata file as not found when it actually exists. The fsspec filesystem classes are still private, but you can try those as a workaround for this issue. You'll need to set a couple of flags for compatibility with gcsfs: sync_files=False, per_thread_copy_ahead=0 (see the sketch at the end of this comment).
  • There is an issue in our use of the upstream deduplication logic, where if a single host wants to write the same shard twice, both writes are eliminated and the shard is dropped from the checkpoint. In your case, model.weight shards are replicated on e.g. devices 0, 2, 4, 6, and since a single process controls devices 0 and 2, these writes are dropped.

You can verify the issue by reversing the sharding to xs.mark_sharding(model.weight, mesh, (None, 'x')), which will cause the replication to be across hosts instead of within-host. In this case, checkpointing will work as expected.

I will send a patch to address this. It seems upstream has moved to a more sophisticated deduplication approach that handles this case and also load-balances the writes.
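
For reference, a rough sketch of both suggestions applied to the repro. The fsspec classes are private, so the exact import path and flag names are assumptions based on the notes above, and the bucket path is hypothetical:

# Workaround for the first issue: write through fsspec/gcsfs instead of a
# gcsfuse mount, with the flags mentioned above for gcsfs compatibility.
from torch.distributed.checkpoint._fsspec_filesystem import FsspecWriter

writer = FsspecWriter(
    "gs://my-bucket/example4",  # hypothetical bucket path
    sync_files=False,
    per_thread_copy_ahead=0,
)

# Verification of the second issue: shard along the 'x' axis so that
# replication is across hosts rather than within a host, which avoids the
# within-host deduplication bug described above.
xs.mark_sharding(model.weight, mesh, (None, 'x'))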

Thanks @dudulightricks for reporting the issue!
