Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supports mps device IndexType #1570

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

bibekyess
Copy link

Motivation

When doing Inference on MMDet CoDINO model, the inference fails raising assertion error in mps device. Specifically this line:

assert isinstance(item, IndexType.__args__)

The issue is that there are no types defined for tensors loaded in mps and mps doesn’t support the following:

AttributeError: module 'torch.mps' has no attribute 'BoolTensor’
AttributeError: module 'torch.mps' has no attribute 'LongTensor’

If I define

elif get_device() == 'mps':
    BoolTypeTensor = Union[torch.BoolTensor, torch.Tensor]
    LongTypeTensor = Union[torch.LongTensor, torch.Tensor]

it throws this error:

File "../site-packages/mmengine/structures/instance_data.py", line 207, in __getitem__
    assert len(item) == len(self), 'The shape of the ' \
AssertionError: The shape of the input(BoolTensor) 257 does not match the shape of the indexed tensor in results_field 300 at first dimension.

So a simple fix is to add torch.Tensor to IndexType types.
Thanks!!

@CLAassistant
Copy link

CLAassistant commented Sep 5, 2024

CLA assistant check
All committers have signed the CLA.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants