
xm.send_cpu_data_to_device cannot support 2d data and 4d mesh #8012

Open
fengyang0317 opened this issue Sep 14, 2024 · 1 comment

🐛 Bug

`xm.send_cpu_data_to_device` fails for 2-D data with an `input_sharding` over a 4-D mesh: a partition spec naming only a subset of the mesh axes (here `('dp', 'sp')`) is accepted by `xs.mark_sharding` but rejected by `xm.send_cpu_data_to_device`.

To Reproduce

https://colab.research.google.com/drive/1URZVd3q0LUeZ8anzrkoJkehOxLdbG1k8?usp=sharing

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
device = xm.xla_device()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (2, 1, 2, 2), ('dp', 'fsdp', 'tp', 'sp'))

# Sharding a device tensor directly works: dim 0 over 'dp', dim 1 over 'sp'
xt = torch.zeros([8, 64]).to(device)
xs.mark_sharding(xt, mesh, ('dp', 'sp'))
print(torch_xla._XLAC._get_xla_sharding_spec(xt))
print(visualize_tensor_sharding(xt))  # This is the desired sharding

# Fails: the same 2-D partition spec over the 4-D mesh is rejected here
xt = torch.zeros([8, 64])
yt = xm.send_cpu_data_to_device(
    xt, device, input_sharding=xs.ShardingSpec(mesh, ('dp', 'sp')))[0]
```
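For reference, the sharding that the working `mark_sharding` call above produces can be sketched in plain NumPy: each mesh axis named in the partition spec shards the corresponding tensor dimension, and (assuming the intended semantics) the mesh axes left unnamed replicate the tensor. The helper variables below are illustrative only, not torch_xla API.

```python
import numpy as np

# Mesh (2, 1, 2, 2) with axes ('dp', 'fsdp', 'tp', 'sp'), as in the repro
mesh_shape = {'dp': 2, 'fsdp': 1, 'tp': 2, 'sp': 2}
partition_spec = ('dp', 'sp')  # dim 0 -> 'dp', dim 1 -> 'sp'

data = np.zeros((8, 64))

# Each named axis divides the matching tensor dimension
shard_shape = tuple(
    dim // mesh_shape[axis] for dim, axis in zip(data.shape, partition_spec)
)

# Axes absent from the spec ('fsdp', 'tp') replicate the tensor
replicas = int(np.prod(
    [n for ax, n in mesh_shape.items() if ax not in partition_spec]
))

print(shard_shape)  # (4, 32)
print(replicas)     # 2
```

This is the per-device layout that `send_cpu_data_to_device` would be expected to produce for the same `ShardingSpec`, matching what `visualize_tensor_sharding` shows for the `mark_sharding` path.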

Steps to reproduce the behavior:

  1. Run the Colab notebook linked above (or the script above).

Expected behavior

The `send_cpu_data_to_device` call runs without error and produces the same `('dp', 'sp')` sharding that `mark_sharding` produces.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: nightly

Additional context

bhavya01 self-assigned this Sep 16, 2024

bhavya01 (Collaborator) commented:

Will look into this later today.
