Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several Improvements for the latest PyTorch Framework #1564

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmengine/_strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ def compile_model(
Returns:
nn.Module: Compiled model.
"""
if isinstance(compile, bool) and not compile:
if isinstance(compile, bool) and not compile or \
isinstance(compile, dict) and not compile.get('disable', False):
return model

assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
``optimizer.state_dict()``
"""
optim_state_dict = FSDP.optim_state_dict_to_load(
state_dict, self.model, self.optim_wrapper.optimizer)
self.model, self.optim_wrapper.optimizer, state_dict)
self.optim_wrapper.load_state_dict(optim_state_dict)

def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion mmengine/optim/optimizer/amp_optimizer_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from functools import partial
from typing import Union

import torch
Expand All @@ -17,7 +18,8 @@
elif is_mlu_available():
from torch.mlu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler as amp_GradScaler
GradScaler = partial(amp_GradScaler, device='cuda')


@OPTIM_WRAPPERS.register_module()
Expand Down
15 changes: 13 additions & 2 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import logging
import os
import os.path as osp
Expand Down Expand Up @@ -902,8 +903,18 @@ def wrap_model(
find_unused_parameters=find_unused_parameters)
else:
model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
model_wrapper_type = MODEL_WRAPPERS.get(
model_wrapper_cfg.get('type')) # type: ignore
model_wrapper_type = model_wrapper_cfg.get('type')
if isinstance(model_wrapper_type, str):
model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore
elif inspect.isclass(model_wrapper_type):
pass
else:
raise KeyError(
f'{model_wrapper_type} is not in the '
'registry. Please check whether the value of '
f'`{model_wrapper_type}` is correct or it was registered '
'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
)
default_args: dict = dict()
if issubclass(
model_wrapper_type, # type: ignore
Expand Down