Registered buffers not moved to correct device when using DeepSpeed Stage 3 #20258

amorehead opened this issue Sep 6, 2024 · 0 comments
Labels: bug, needs triage, ver: 2.4.x


Bug description

Using the following DeepSpeedStrategy configuration:

_target_: lightning.pytorch.strategies.DeepSpeedStrategy
zero_optimization: true
stage: 3
allgather_bucket_size: 2e8
reduce_bucket_size: 2e8
offload_optimizer: false
offload_parameters: false
partition_activations: false
cpu_checkpointing: false
contiguous_gradients: false
overlap_comm: false

I am experiencing an issue (specifically with DeepSpeed stage 3, not stages 1-2) where tensors registered via register_buffer() within the sub-nn.Modules of my LightningModule's main nn.Module are not moved to the correct device upon training the model. In particular, I am trying to register buffers as

distance_bins_tensor = tensor([0.0, 1.0, 2.0, 3.0])
self.register_buffer("distance_bins", distance_bins_tensor)

within the various submodules of my model. When my optimizer tries to perform a step, I get the error

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!

when trying to use these registered buffers, e.g., by multiplying them by feature tensors loaded onto (in this case) cuda:6.
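To make the report easier to narrow down, here is a minimal sketch of the pattern described above. It is not the original model code: the class name DistanceBinner and the helper misplaced_buffers are hypothetical, and calling .to(features.device) inside forward() is only an assumed defensive workaround, not a confirmed fix for the stage 3 behavior. The helper lists registered buffers that are not on an expected device, which could be called on the LightningModule right before the failing step.

```python
import torch
from torch import nn


class DistanceBinner(nn.Module):
    """Hypothetical submodule mirroring the buffer registration above."""

    def __init__(self) -> None:
        super().__init__()
        distance_bins_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0])
        self.register_buffer("distance_bins", distance_bins_tensor)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Assumed workaround: align the buffer with the input at use time
        # instead of relying on the strategy to have moved it already.
        bins = self.distance_bins.to(features.device)
        return features * bins


def misplaced_buffers(module: nn.Module, expected: torch.device) -> list[str]:
    """Return the names of registered buffers that are not on `expected`.

    Devices are compared exactly, so pass a fully specified device
    (e.g. torch.device("cuda:6") rather than torch.device("cuda")).
    """
    return [
        name
        for name, buf in module.named_buffers()
        if buf.device != expected
    ]


model = nn.Sequential(DistanceBinner())

# Buffers are created on CPU, so they are reported when a different
# device is expected ("meta" stands in for a GPU on CPU-only machines).
print(misplaced_buffers(model, torch.device("meta")))

# With plain PyTorch, Module.to() moves registered buffers with the module.
print(misplaced_buffers(model.to("meta"), torch.device("meta")))
```

Inside a LightningModule, the same check could be run as misplaced_buffers(self, self.device) to see which buffers were left on the CPU under stage 3.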

How to reproduce the bug

Error messages and logs



Current environment
  • CUDA:
    - GPU:
      - NVIDIA A100 80GB PCIe
      - NVIDIA A100 80GB PCIe
    - available: True
    - version: 11.8
  • Lightning:
    - adam-atan2-pytorch: 0.0.10
    - alphafold3-pytorch: 0.0.41
    - alphafold3-pytorch-lightning-hydra: 0.1.111
    - frame-averaging-pytorch: 0.0.19
    - lightning: 2.4.0
    - lightning-utilities: 0.11.6
    - pytorch-lightning: 2.4.0
    - rotary-embedding-torch: 0.6.1
    - torch: 2.3.0+cu118
    - torch-geometric: 2.5.3
    - torchaudio: 2.3.0+cu118
    - torchmetrics: 1.4.1
    - torchtyping: 0.1.4
    - torchvision: 0.18.0+cu118
  • Packages:
    - adam-atan2-pytorch: 0.0.10
    - aiofiles: 23.2.1
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alembic: 1.13.1
    - alphafold3-pytorch: 0.0.41
    - alphafold3-pytorch-lightning-hydra: 0.1.111
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - appdirs: 1.4.4
    - argcomplete: 3.3.0
    - asttokens: 2.4.1
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - autopage: 0.5.2
    - beartype: 0.18.5
    - beautifulsoup4: 4.12.3
    - biopandas: 0.5.1.dev0
    - biopython: 1.83
    - bioservices: 1.11.2
    - cattrs: 23.2.3
    - certifi: 2024.8.30
    - cfgv: 3.4.0
    - chardet: 5.2.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - cliff: 4.7.0
    - cmaes: 0.10.0
    - cmd2: 2.4.3
    - colorama: 0.4.6
    - colorlog: 6.8.2
    - colt5-attention: 0.11.0
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - deepdiff: 7.0.1
    - deepspeed: 0.15.0
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - easydev: 0.13.2
    - einops: 0.8.0
    - einx: 0.2.2
    - environs: 11.0.0
    - exceptiongroup: 1.2.1
    - executing: 2.0.1
    - fastapi: 0.112.2
    - ffmpy: 0.4.0
    - filelock: 3.13.1
    - fonttools: 4.52.4
    - frame-averaging-pytorch: 0.0.19
    - freetype-py: 2.3.0
    - frozendict: 2.4.4
    - frozenlist: 1.4.1
    - fsspec: 2024.2.0
    - gemmi: 0.6.6
    - gevent: 24.2.1
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - gradio: 4.43.0
    - gradio-client: 1.3.0
    - gradio-molecule3d: 0.0.5
    - graphein: 1.7.6
    - greenlet: 3.0.3
    - grequests: 0.7.0
    - h11: 0.14.0
    - hjson: 3.1.0
    - httpcore: 1.0.5
    - httpx: 0.27.2
    - huggingface-hub: 0.23.4
    - hydra-colorlog: 1.2.0
    - hydra-core: 1.3.2
    - hydra-optuna-sweeper: 1.2.0
    - identify: 2.5.36
    - idna: 3.7
    - importlib-resources: 6.4.4
    - iniconfig: 2.0.0
    - ipykernel: 6.29.4
    - ipython: 8.24.0
    - jaxtyping: 0.2.28
    - jedi: 0.19.1
    - jinja2: 3.1.3
    - joblib: 1.4.2
    - jupyter-client: 8.6.2
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.5
    - lightning: 2.4.0
    - lightning-utilities: 0.11.6
    - line-profiler: 4.1.3
    - local-attention: 1.9.1
    - loguru: 0.7.2
    - looseversion: 1.1.2
    - lxml: 5.2.2
    - mako: 1.3.5
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - marshmallow: 3.21.3
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.7
    - mdurl: 0.1.2
    - mmtf-python: 1.1.3
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - multipledispatch: 1.0.0
    - munkres: 1.1.4
    - nest-asyncio: 1.6.0
    - networkx: 3.2.1
    - ninja:
    - nodeenv: 1.8.0
    - numpy: 1.23.5
    - nvidia-cublas-cu11:
    - nvidia-cuda-cupti-cu11: 11.8.87
    - nvidia-cuda-nvrtc-cu11: 11.8.89
    - nvidia-cuda-runtime-cu11: 11.8.89
    - nvidia-cudnn-cu11:
    - nvidia-cufft-cu11:
    - nvidia-curand-cu11:
    - nvidia-cusolver-cu11:
    - nvidia-cusparse-cu11:
    - nvidia-ml-py: 12.560.30
    - nvidia-nccl-cu11: 2.20.5
    - nvidia-nvtx-cu11: 11.8.86
    - omegaconf: 2.3.0
    - optree: 0.11.0
    - optuna: 2.10.1
    - ordered-set: 4.1.0
    - orjson: 3.10.7
    - packaging: 24.0
    - pandas: 1.5.3
    - parso: 0.8.4
    - pbr: 6.0.0
    - pdbeccdutils: 0.8.5
    - pexpect: 4.9.0
    - pillow: 10.2.0
    - pip: 24.0
    - pipx: 1.5.0
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 1.3.0
    - pre-commit: 3.7.1
    - prettytable: 3.10.0
    - prompt-toolkit: 3.0.45
    - protobuf: 4.25.4
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pycairo: 1.26.0
    - pydantic: 2.8.2
    - pydantic-core: 2.20.1
    - pydub: 0.25.1
    - pygments: 2.18.0
    - pyparsing: 3.1.2
    - pyperclip: 1.8.2
    - pytest: 8.2.1
    - python-dateutil: 2.9.0
    - python-dotenv: 1.0.1
    - python-multipart: 0.0.9
    - pytorch-lightning: 2.4.0
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - rdkit: 2024.3.2
    - reportlab: 4.1.0
    - requests: 2.32.2
    - requests-cache: 1.2.0
    - retrying: 1.3.4
    - rich: 13.7.1
    - rich-click: 1.8.2
    - rlpycairo: 0.2.0
    - rootutils: 1.0.7
    - rotary-embedding-torch: 0.6.1
    - ruff: 0.6.4
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - seaborn: 0.13.2
    - semantic-version: 2.10.0
    - sentry-sdk: 2.12.0
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - sh: 2.0.7
    - shellingham: 1.5.4
    - shortuuid: 1.0.13
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - soupsieve: 2.5
    - sqlalchemy: 2.0.30
    - stack-data: 0.6.3
    - starlette: 0.38.4
    - stevedore: 5.2.0
    - suds-community: 1.1.2
    - sympy: 1.12
    - taylor-series-linear-attention: 0.1.12
    - tenacity: 8.3.0
    - threadpoolctl: 3.5.0
    - timeout-decorator: 0.5.0
    - tomli: 2.0.1
    - tomlkit: 0.12.0
    - torch: 2.3.0+cu118
    - torch-geometric: 2.5.3
    - torchaudio: 2.3.0+cu118
    - torchmetrics: 1.4.1
    - torchtyping: 0.1.4
    - torchvision: 0.18.0+cu118
    - tornado: 6.4
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - triton: 2.3.0
    - typeguard: 2.13.3
    - typer: 0.12.5
    - typing-extensions: 4.11.0
    - tzdata: 2024.1
    - unicodedata2: 15.1.0
    - url-normalize: 1.4.3
    - urllib3: 2.2.1
    - userpath: 1.9.2
    - uvicorn: 0.30.6
    - virtualenv: 20.26.2
    - wandb: 0.16.6
    - wcwidth: 0.2.13
    - websockets: 12.0
    - wget: 3.2
    - wheel: 0.43.0
    - wrapt: 1.16.0
    - xarray: 2024.3.0
    - xmltodict: 0.13.0
    - yarl: 1.9.4
    - zope.event: 5.0
    - zope.interface: 6.4.post2
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 4.18.0-553.16.1.el8_10.x86_64
    - version: #1 SMP Thu Aug 8 07:11:46 EDT 2024

More info

amorehead added the bug and needs triage labels on Sep 6, 2024
