
[Bug] When I use RMosaic data augmentation with the workflow set to [('train', 1), ('val', 1)], I get "KeyError: 'img'". When I change the workflow back to [('train', 1)], the error disappears. Why does this happen? #1063

Open
3 tasks done
wu325 opened this issue Sep 4, 2024 · 0 comments


wu325 commented Sep 4, 2024

Prerequisite

Task

I have modified the scripts/configs, or I'm working on my own tasks/models/datasets.

Branch

master branch https://github.com/open-mmlab/mmrotate

Environment

fatal: not a git repository (or any of the parent directories): .git
sys.platform: win32
Python: 3.8.19 (default, Mar 20 2024, 19:55:45) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
GPU 0: NVIDIA GeForce RTX 3060 Ti
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.40.33811 for x64
GCC: n/a
PyTorch: 1.8.0+cu111
PyTorch compiling details: PyTorch built with:

  • C++ Version: 199711
  • MSVC 192829337
  • Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  • OpenMP 2019
  • CPU capability usage: AVX2
  • CUDA Runtime 11.1
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  • CuDNN 8.0.5
  • Magma 2.5.4
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=C:/w/b/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -DNDEBUG -DUSE_FBGEMM -DUSE_XNNPACK, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON,

TorchVision: 0.9.0+cu111
OpenCV: 4.10.0
MMCV: 1.6.0
MMCV Compiler: MSVC 192930137
MMCV CUDA Compiler: 11.1
MMRotate: 0.3.4+

Reproduces the problem - code sample

# The new config inherits the settings of the base model

from mmdet.core.evaluation import class_names

_base_ = '../../configs/lsknet/lsk_s_fpn_1x_dota_le90.py'

# 1. Dataset settings

dataset_type = 'DOTADataset'
data_root = 'data/test/'
angle_version = 'le90'
classes = ('dam', 'groyne', 'lock', 'sluice', 'weir')
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(
        type='RRandomFlip',
        flip_ratio=[0.25, 0.25, 0.25],
        direction=['horizontal', 'vertical', 'diagonal'],
        version=angle_version),
    dict(
        type='RMosaic',  # mosaic data augmentation
        img_scale=(512, 512),
        center_ratio_range=(0.8, 1.2),
        min_bbox_size=6,
        bbox_clip_border=True,
        skip_filter=True,
        version=angle_version),
    dict(
        type='PolyRandomRotate',  # rotation augmentation for the image and its bounding boxes
        rotate_ratio=0.5,
        angles_range=180,
        auto_bound=False,
        version=angle_version),
    dict(
        type='RandomPhotoMetricDistortion',  # photometric distortion augmentation
        prob=0.5,
        brightness_delta=0,  # brightness change range
        contrast_range=(0.9, 1.1),  # contrast change range
        saturation_range=(0.9, 1.1),  # saturation change range
        hue_delta=18),  # hue change range
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=4,
    train=dict(
        _delete_=True,
        type='MultiImageMixDataset',
        dataset=dict(
            type=dataset_type,
            classes=classes,
            ann_file=data_root + 'train/annfiles/',
            img_prefix=data_root + 'train/images/',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations', with_bbox=True)
            ],
            version=angle_version
        ),
        pipeline=train_pipeline,
        max_refetch=500),
    val=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'val/annfiles/',
        img_prefix=data_root + 'val/images/'),
    test=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'test/annfiles/',
        img_prefix=data_root + 'test/images/'))

# 2. Optimizer settings

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,  # /8*gpu_number,
    betas=(0.9, 0.999),
    weight_decay=0.05)
lr_config = dict(
    _delete_=True,
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr=0,
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3)
evaluation = dict(interval=1, metric='mAP', save_best='mAP')
runner = dict(type='EpochBasedRunner', max_epochs=12)

# 3. Output settings

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
workflow = [('train', 1), ('val', 1)]

# 4. Model settings

gpu_number = 1  # change this to your own number of GPUs
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint='checkpoints/lsk_s_backbone.pth.tar'),
        norm_cfg=dict(type='BN', requires_grad=True)),  # use BN instead of SyncBN for single-GPU training
    roi_head=dict(
        bbox_head=dict(
            num_classes=5)))
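In this config the image and annotation loading steps exist only in the inner `DOTADataset` pipeline. In mmdet 2.x, `MultiImageMixDataset` first runs that inner pipeline, then fetches extra samples for any transform that has a `get_indexes` method (such as the mosaic), and only then applies the outer `train_pipeline`. That ordering is why `RRandomFlip` and `RMosaic` see an `img` key during training. The sketch below is a simplified, dependency-free illustration of that ordering. Every class and function in it (`ToyInnerDataset`, `ToyMixDataset`, `ToyMosaic`, `load_image`, `flip`) is a hypothetical stand-in, not the mmdet/mmrotate API.

```python
import random
from typing import Callable, Dict, List

Results = Dict[str, object]
Transform = Callable[[Results], Results]


def load_image(results: Results) -> Results:
    """Hypothetical stand-in for LoadImageFromFile: creates the 'img' key."""
    results["img"] = f"pixels of {results['filename']}"
    return results


def flip(results: Results) -> Results:
    """Hypothetical stand-in for RRandomFlip: reads 'img', so a loader must run first."""
    results["img"] = f"flipped({results['img']})"
    return results


class ToyMosaic:
    """Hypothetical stand-in for RMosaic: combines the sample with 3 others."""

    def get_indexes(self, num_samples: int) -> List[int]:
        return [random.randrange(num_samples) for _ in range(3)]

    def __call__(self, results: Results) -> Results:
        # Fail loudly if the wrapper did not supply the extra samples.
        assert "mix_results" in results, "mosaic needs 'mix_results'"
        parts = [results["img"]] + [r["img"] for r in results["mix_results"]]
        results["img"] = "mosaic(" + ", ".join(parts) + ")"
        return results


class ToyInnerDataset:
    """Hypothetical stand-in for DOTADataset whose pipeline only loads data."""

    def __init__(self, filenames: List[str]):
        self.filenames = filenames

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Results:
        return load_image({"filename": self.filenames[idx]})


class ToyMixDataset:
    """Simplified MultiImageMixDataset: inner pipeline first, outer pipeline second."""

    def __init__(self, dataset: ToyInnerDataset, pipeline: List[Transform]):
        self.dataset = dataset
        self.pipeline = pipeline

    def __getitem__(self, idx: int) -> Results:
        results = dict(self.dataset[idx])  # 'img' already exists at this point
        for transform in self.pipeline:
            if hasattr(transform, "get_indexes"):
                indexes = transform.get_indexes(len(self.dataset))
                results["mix_results"] = [self.dataset[i] for i in indexes]
            results = transform(results)
            results.pop("mix_results", None)
        return results


train_set = ToyMixDataset(
    ToyInnerDataset(["a.png", "b.png", "c.png", "d.png"]),
    pipeline=[flip, ToyMosaic()],
)
sample = train_set[0]
print(sample["img"])  # the three extra images are chosen at random
```

The point of the sketch is the ordering. Only the wrapper guarantees that `img` and `mix_results` exist before the outer pipeline runs, so the outer pipeline cannot be reused on a plain dataset.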

Reproduces the problem - command or script

python tools/train.py my_demo/configs/my_lsk_s_orcnn_fpn_1x_dota_le90.py --work-dir my_demo/output/debug/lsk_orcnn
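With a two-entry workflow, `tools/train.py` in mmdet 2.x builds a second dataset for the `('val', 1)` stage with `val_dataset = copy.deepcopy(cfg.data.val)` followed by `val_dataset.pipeline = cfg.data.train.pipeline`. As far as I can tell, mmrotate 0.x follows the same pattern, but please check your local copy of the script. With the config above, `cfg.data.train.pipeline` is the outer `train_pipeline` of `MultiImageMixDataset`, and that pipeline contains no loading step. The sketch below replays that copy on plain dicts to show which pipeline the validation-stage `DOTADataset` would receive. The dicts are a trimmed, hypothetical subset of the config, not mmcv's `Config` object.

```python
import copy

# Trimmed, hypothetical subset of the config above, as plain dicts.
train_pipeline = [
    dict(type="RRandomFlip"),
    dict(type="RMosaic"),
    dict(type="PolyRandomRotate"),
    dict(type="RandomPhotoMetricDistortion"),
    dict(type="Normalize"),
    dict(type="Pad"),
    dict(type="DefaultFormatBundle"),
    dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels"]),
]
data = dict(
    train=dict(
        type="MultiImageMixDataset",
        dataset=dict(
            type="DOTADataset",
            pipeline=[dict(type="LoadImageFromFile"),
                      dict(type="LoadAnnotations", with_bbox=True)],
        ),
        pipeline=train_pipeline,
    ),
    val=dict(type="DOTADataset", ann_file="data/test/val/annfiles/"),
)
workflow = [("train", 1), ("val", 1)]

dataset_cfgs = [data["train"]]
if len(workflow) == 2:
    # Mirrors the tools/train.py logic described above: copy the val config and
    # give it the *outer* train pipeline, which has no loading step.
    val_dataset = copy.deepcopy(data["val"])
    val_dataset["pipeline"] = data["train"]["pipeline"]
    dataset_cfgs.append(val_dataset)

val_steps = [step["type"] for step in dataset_cfgs[1]["pipeline"]]
print(val_steps[0])                     # RRandomFlip
print("LoadImageFromFile" in val_steps)  # False
```

The sketch also suggests that prepending the loading steps to this copied pipeline would probably not be enough. In mmdet 2.x, `Mosaic` asserts that `mix_results` is present, and a plain `DOTADataset` never supplies it. That conclusion is an inference from reading the code, not something tested here.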

Reproduces the problem - error message

2024-09-04 11:10:24,555 - mmrotate - INFO - workflow: [('train', 1), ('val', 1)], max: 12 epochs
2024-09-04 11:10:24,555 - mmrotate - INFO - Checkpoints will be saved to G:\mmrotate\my_demo\output\debug\lsk_orcnn by HardDiskBackend.
D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\models\dense_heads\anchor_head.py:123: UserWarning: DeprecationWarning: anchor_generator is deprecated, please use "prior_generator" instead
  warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
2024-09-04 11:11:20,375 - mmrotate - INFO - Epoch [1][50/95] lr: 3.969e-05, eta: 0:20:14, time: 1.114, data_time: 0.271, memory: 7126, loss_rpn_cls: 0.5340, loss_rpn_bbox: 0.1305, loss_cls: 0.3892, acc: 92.8320, loss_bbox: 0.0319, loss: 1.0856, grad_norm: 12.6486
2024-09-04 11:11:51,915 - mmrotate - INFO - Saving checkpoint at 1 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 7/7, 1.0 task/s, elapsed: 7s, ETA: 0s
2024-09-04 11:12:11,302 - mmrotate - INFO -
+--------+-----+------+--------+-------+
| class  | gts | dets | recall | ap    |
+--------+-----+------+--------+-------+
| dam    | 2   | 129  | 0.000  | 0.000 |
| groyne | 0   | 6    | 0.000  | 0.000 |
| lock   | 0   | 3    | 0.000  | 0.000 |
| sluice | 5   | 8    | 0.000  | 0.000 |
| weir   | 0   | 15   | 0.000  | 0.000 |
+--------+-----+------+--------+-------+
| mAP    |     |      |        | 0.000 |
+--------+-----+------+--------+-------+
2024-09-04 11:12:11,340 - mmrotate - INFO - Exp name: my_lsk_s_orcnn_fpn_1x_dota_le90.py
2024-09-04 11:12:11,340 - mmrotate - INFO - Epoch(val) [1][7] mAP: 0.0000
Traceback (most recent call last):
  File "tools/train.py", line 192, in <module>
    main()
  File "tools/train.py", line 181, in main
    train_detector(
  File "g:\mmrotate\mmrotate\apis\train.py", line 141, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 68, in val
    for i, data_batch in enumerate(self.data_loader):
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\dataloader.py", line 517, in __next__
    data = self._next_data()
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\dataloader.py", line 1199, in _next_data
    return self._process_data(data)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\dataloader.py", line 1225, in _process_data
    data.reraise()
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\_utils.py", line 429, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\_utils\worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\custom.py", line 218, in __getitem__
    data = self.prepare_train_img(idx)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\custom.py", line 241, in prepare_train_img
    return self.pipeline(results)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\pipelines\compose.py", line 41, in __call__
    data = t(data)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\pipelines\transforms.py", line 469, in __call__
    results[key], direction=results['flip_direction'])
KeyError: 'img'
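The last frames point to mmdet 2.x `RandomFlip.__call__`, which `RRandomFlip` inherits. When a flip is chosen, it reads `results[key]` for each key in `results.get('img_fields', ['img'])`. Both `img_fields` and `img` are normally added by `LoadImageFromFile`. The toy `Compose` below is a hypothetical, dependency-free stand-in for mmdet's `Compose`, not its actual code. It reproduces the same `KeyError: 'img'` when a flip runs on a `results` dict that no loader has touched. It then shows the error disappearing once a loading step runs first.

```python
from typing import Callable, Dict, List

Results = Dict[str, object]


def load_image_from_file(results: Results) -> Results:
    """Hypothetical stand-in for LoadImageFromFile: adds 'img' and 'img_fields'."""
    results["img"] = [[0, 1], [2, 3]]  # tiny fake 2x2 image
    results["img_fields"] = ["img"]
    return results


def horizontal_flip(results: Results) -> Results:
    """Hypothetical stand-in for RandomFlip when a horizontal flip is chosen."""
    results["flip"] = True
    results["flip_direction"] = "horizontal"
    for key in results.get("img_fields", ["img"]):
        # Raises KeyError('img') if no loading step has run yet.
        results[key] = [row[::-1] for row in results[key]]
    return results


class Compose:
    """Apply transforms in order, like a minimal data pipeline."""

    def __init__(self, transforms: List[Callable[[Results], Results]]):
        self.transforms = transforms

    def __call__(self, results: Results) -> Results:
        for t in self.transforms:
            results = t(results)
        return results


def run(pipeline: Compose) -> str:
    try:
        out = pipeline({"filename": "P0001.png"})
        return f"ok: {out['img']}"
    except KeyError as err:
        return f"KeyError: {err}"


# Roughly what the ('val', 1) stage runs: the flip is the first step.
print(run(Compose([horizontal_flip])))                        # KeyError: 'img'
# With a loading step in front, the same flip succeeds.
print(run(Compose([load_image_from_file, horizontal_flip])))  # ok: [[1, 0], [3, 2]]
```

Two workarounds follow from this, neither tested here. The first is to keep `workflow = [('train', 1)]` and rely on the `evaluation` hook, which appears to be what produced the mAP table in the log before the crash. The second is to give the validation-stage dataset its own pipeline that includes the loading steps and omits `RMosaic`.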

Additional information

This is only a test dataset, so the mAP of 0 can be ignored.
