Describe the bug
Building Video Swin Transformer in torch.bfloat16 and running it fails with the following error:
.........
File "/data4/Projects/video_captioning/my_project/experiments/modeling/swin_transformer.py", line 284, in forward
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
RuntimeError: expected scalar type BFloat16 but found Float
Because I have modified this code in my project, my line numbers differ from the repository's. The corresponding line in this repository's swin_transformer.py is https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L166
and the function that causes the bug is https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L317.
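The mechanism can be reproduced in isolation. This is a minimal sketch, assuming (as the extra x.dtype argument in the fix suggests) that compute_mask creates its mask with torch.zeros in the default float32. The tensor shapes are hypothetical. Adding a float32 mask to bfloat16 attention scores promotes the result to float32, and the later attn @ v then multiplies float32 by bfloat16, which matmul does not promote, so it raises a RuntimeError (the exact message varies by PyTorch version and device).

```python
import torch

# Hypothetical sizes: 2 windows, 3 heads, 8 tokens per window, head_dim 4.
B_, heads, N, head_dim = 2, 3, 8, 4
attn = torch.randn(B_, heads, N, N, dtype=torch.bfloat16)   # q @ k^T in bf16
v = torch.randn(B_, heads, N, head_dim, dtype=torch.bfloat16)

# Mask created without a dtype -> default float32 (assumed upstream behaviour).
mask = torch.zeros(B_, 1, N, N)

# Elementwise add follows type promotion: bfloat16 + float32 -> float32.
promoted = attn + mask
print(promoted.dtype)  # torch.float32

probs = promoted.softmax(dim=-1)  # still float32

raised = False
try:
    x = probs @ v  # float32 @ bfloat16: matmul requires matching dtypes
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```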
Bug fix
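One way to fix it is to add a dtype argument to compute_mask so that the attention mask is created in the same dtype as the input.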
Then change the line at https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py#L405 to:
attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device, x.dtype)
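A self-contained sketch of what the patched compute_mask could look like, assuming it otherwise keeps the structure of the upstream function. The window_partition helper is re-implemented here for the 3D (B, D, H, W, C) layout, and the clip size, head count and head_dim in the usage check are hypothetical. The only intended change is that img_mask is built with the caller's dtype, so the mask added to attn stays in bfloat16 and attn @ v no longer mixes dtypes.

```python
from functools import lru_cache

import torch


def window_partition(x, window_size):
    """Split (B, D, H, W, C) into (num_windows * B, wd * wh * ww, C) windows."""
    B, D, H, W, C = x.shape
    wd, wh, ww = window_size
    x = x.view(B, D // wd, wd, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, wd * wh * ww, C)


@lru_cache()
def compute_mask(D, H, W, window_size, shift_size, device, dtype=torch.float32):
    # Region-id map created directly in the requested dtype (the fix).
    img_mask = torch.zeros((1, D, H, W, 1), device=device, dtype=dtype)
    cnt = 0
    for d in (slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None)):
        for h in (slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None)):
            for w in (slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None)):
                img_mask[:, d, h, w, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, window_size).squeeze(-1)  # nW, N
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # nW, N, N
    # Tokens from different regions get -100, same region gets 0.
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask


# Usage check with hypothetical shapes: one padded clip of 4 x 14 x 14, 2 heads.
x = torch.randn(1, 4, 14, 14, 8, dtype=torch.bfloat16)  # B, Dp, Hp, Wp, C
window_size, shift_size = (2, 7, 7), (1, 3, 3)
attn_mask = compute_mask(4, 14, 14, window_size, shift_size, x.device, x.dtype)

nW, N, heads = attn_mask.shape[0], attn_mask.shape[1], 2
attn = torch.randn(nW, heads, N, N, dtype=x.dtype)
v = torch.randn(nW, heads, N, 4, dtype=x.dtype)
# Same broadcasting pattern as the window attention: (1, nW, heads, N, N) + (1, nW, 1, N, N).
attn = attn.view(1, nW, heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, heads, N, N)  # softmax, which preserves dtype, omitted here
out = attn @ v
print(attn_mask.dtype, out.dtype)
```

Because dtype is now part of the lru_cache key, float32 and bfloat16 masks are cached as separate entries rather than one dtype silently being reused for the other. The region ids (at most 26) and the fill value -100 are exactly representable in bfloat16, so the mask values are unchanged by the lower precision.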