Describe the bug
When I use the Albu transform in an mmdet pipeline for data augmentation, the following error is raised:
RuntimeError: Index put requires the source and destination dtypes match, got Long for the destination and Double for the source.
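My diagnosis is that the class labels created during the box IoU assignment step are integers (Long), while the gt_box_label values being assigned into them are floats (Double).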
Error traceback
Traceback (most recent call last):
File "tools/train.py", line 121, in <module>
main()
File "tools/train.py", line 117, in main
runner.train()
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/mmengine/runner/loops.py", line 287, in run
self.run_iter(data_batch)
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/mmengine/runner/loops.py", line 311, in run_iter
outputs = self.runner.model.train_step(
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/mmengine/model/wrappers/distributed.py", line 121, in train_step
losses = self._run_forward(data, mode='loss')
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/mmengine/model/wrappers/distributed.py", line 161, in _run_forward
results = self(**data, mode=mode)
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
else self._run_ddp_forward(*inputs, **kwargs)
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
return self.module(*inputs, **kwargs) # type: ignore[index]
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/cnroge/.conda/envs/ccccnroge/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/cnroge/mmdetection/mmdet/models/detectors/base.py", line 92, in forward
return self.loss(inputs, data_samples)
File "/home/cnroge/mmdetection/mmdet/models/detectors/two_stage.py", line 174, in loss
rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
File "/home/cnroge/mmdetection/mmdet/models/dense_heads/base_dense_head.py", line 165, in loss_and_predict
losses = self.loss_by_feat(*loss_inputs)
File "/home/cnroge/mmdetection/mmdet/models/dense_heads/rpn_head.py", line 125, in loss_by_feat
losses = super().loss_by_feat(
File "/home/cnroge/mmdetection/mmdet/models/dense_heads/anchor_head.py", line 502, in loss_by_feat
cls_reg_targets = self.get_targets(
File "/home/cnroge/mmdetection/mmdet/models/dense_heads/anchor_head.py", line 378, in get_targets
results = multi_apply(
File "/home/cnroge/mmdetection/mmdet/models/utils/misc.py", line 219, in multi_apply
return tuple(map(list, zip(*map_results)))
File "/home/cnroge/mmdetection/mmdet/models/dense_heads/anchor_head.py", line 252, in _get_targets_single
assign_result = self.assigner.assign(pred_instances, gt_instances,
File "/home/cnroge/mmdetection/mmdet/models/task_modules/assigners/max_iou_assigner.py", line 234, in assign
assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
File "/home/cnroge/mmdetection/mmdet/models/task_modules/assigners/max_iou_assigner.py", line 318, in assign_wrt_overlaps
assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
RuntimeError: Index put requires the source and destination dtypes match, got Long for the destination and Double for the source.
mmdet.datasets.transforms.transforms.py
class Albu(BaseTransform):
    ...
    def _postprocess_results(...):
        ...
        # line 1774
        results[label] = np.array(
            [results[label][i] for i in results['idx_mapper']])
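I changed line 1774 to:
results[label] = np.array(
    [results[label][i] for i in results['idx_mapper']], dtype=np.int64)
This keeps the label dtype the same before and after the numpy conversion, which is enough to avoid the error.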
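The dtype change can be reproduced with numpy alone. The sketch below is a minimal illustration, not the mmdet code itself: `remap_labels` is a hypothetical helper that mimics the list comprehension on line 1774, and the label values are made up. It relies on the fact that `np.array([])` defaults to float64 when no dtype is given.

```python
import numpy as np

# Hypothetical stand-in for the remapping done on line 1774;
# the name and signature are assumptions made for illustration.
def remap_labels(labels, idx_mapper, dtype=None):
    return np.array([labels[i] for i in idx_mapper], dtype=dtype)

# Made-up integer labels, as they would look before augmentation.
labels = np.array([3, 7, 1], dtype=np.int64)

# Some boxes survive augmentation: the dtype is inferred from the int64 items.
kept = remap_labels(labels, [0, 2])

# No boxes survive: numpy has nothing to infer from and falls back to float64.
empty_default = remap_labels(labels, [])

# Passing dtype explicitly keeps the labels integer even when empty.
empty_fixed = remap_labels(labels, [], dtype=np.int64)

print(kept.dtype, empty_default.dtype, empty_fixed.dtype)
```

Only the empty case changes dtype, which would explain why the error shows up only for samples that end up with no labels after augmentation.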
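cnroge changed the title from "When using Albu for data augmentation, passing an empty gt_box_label causes a dtype mismatch error in mmdet" to "When using Albu for data augmentation, passing an empty gt_box_label causes a dtype mismatch error in mmdet: got Long for the destination and Double for the source" on Sep 14, 2024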
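Bug fix
After investigating, I found that when gt_bboxes_labels is empty, it is converted to float64 after passing through Albu. The conversion happens in Albu._postprocess_results, at line 1774 of mmdet.datasets.transforms.transforms.py, as quoted earlier in this report.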