diff --git a/configs/localization/bsn/metafile.yml b/configs/localization/bsn/metafile.yml
index 693c67d840..9481c15334 100644
--- a/configs/localization/bsn/metafile.yml
+++ b/configs/localization/bsn/metafile.yml
@@ -8,9 +8,7 @@ Collections:
 Models:
   - Name: bsn_400x100_1xb16_20e_activitynet_feature (cuhk_mean_100)
     Config:
-      - configs/localization/bsn/bsn_tem_1xb16-400x100-20e_activitynet-feature.py
-      - configs/localization/bsn/bsn_pgm_400x100_activitynet-feature.py
-      - configs/localization/bsn/bsn_pem_1xb16-400x100-20e_activitynet-feature.py
+      configs/localization/bsn/bsn_pem_1xb16-400x100-20e_activitynet-feature.py
     In Collection: BSN
     Metadata:
       Batch Size: 16
@@ -18,6 +16,10 @@ Models:
       Training Data: ActivityNet v1.3
       Training Resources: 1 GPU
       feature: cuhk_mean_100
+      configs:
+        - configs/localization/bsn/bsn_tem_1xb16-400x100-20e_activitynet-feature.py
+        - configs/localization/bsn/bsn_pgm_400x100_activitynet-feature.py
+        - configs/localization/bsn/bsn_pem_1xb16-400x100-20e_activitynet-feature.py
     Modality: RGB
     Results:
       - Dataset: ActivityNet v1.3
diff --git a/docs/en/user_guides/finetune.md b/docs/en/user_guides/finetune.md
index a41bcf3a49..20482f0e63 100644
--- a/docs/en/user_guides/finetune.md
+++ b/docs/en/user_guides/finetune.md
@@ -45,7 +45,7 @@ model = dict(
 
 MMAction2 supports UCF101, Kinetics-400, Moments in Time, Multi-Moments in Time, THUMOS14, Something-Something V1&V2, ActivityNet Dataset.
 The users may need to adapt one of the above datasets to fit their special datasets.
-You could refer to [Prepare Dataset](prepare_dataset.md) and [Customize Datast](../advanced_guides/customize_dataset.md) for more details.
+You could refer to [Prepare Dataset](prepare_dataset.md) and [Customize Dataset](../advanced_guides/customize_dataset.md) for more details.
 
 In our case, UCF101 is already supported by various dataset types, like `VideoDataset`, so we change the config as follows.
 
diff --git a/projects/__init__.py b/projects/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/projects/basic_tad/README.md b/projects/basic_tad/README.md
new file mode 100644
index 0000000000..2b326bfdfd
--- /dev/null
+++ b/projects/basic_tad/README.md
@@ -0,0 +1,133 @@
+# BasicTAD
+
+This project implements the BasicTAD model in MMAction2. Please refer to the [official repo](https://github.com/MCG-NJU/BasicTAD) and the [paper](https://arxiv.org/abs/2205.02717) for details.
+
+## Usage
+
+### Setup Environment
+
+Please refer to [Get Started](https://mmaction2.readthedocs.io/en/latest/get_started/installation.html) to install MMAction2 and MMDetection.
+
+First, add the current folder to `PYTHONPATH` so that Python can find your code. Run the following command in the current directory to add it.
+
+> Please run it every time after you open a new shell.
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Data Preparation
+
+Prepare the THUMOS14 dataset according to the [instruction](https://github.com/open-mmlab/mmaction2/blob/main/tools/data/thumos14/README.md).
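+
+The configs in this project expect the layout sketched below. It is inferred from the `data_root`, `data_prefix_*`, `ann_file_*` and `filename_tmpl` fields of the configs, so treat it as a guide and adjust those fields if your paths differ. Each entry under `database` in the annotation JSON provides the video `duration` and a list of `annotations` with a `label` and a `segment` (start/end in seconds), which is what `Thumos14Dataset` reads.
+
+```text
+data/thumos14
+├── annotations
+│   └── basicTAD
+│       ├── val.json    # training annotations (THUMOS14 trains on the val subset)
+│       └── test.json   # validation / testing annotations
+└── rawframes
+    ├── val
+    │   └── <video_name>/img_00001.jpg, img_00002.jpg, ...
+    └── test
+        └── <video_name>/img_00001.jpg, img_00002.jpg, ...
+```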
+
+### Training commands
+
+**To train with a single GPU:**
+
+```bash
+mim train mmaction configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py
+```
+
+**To train with multiple GPUs:**
+
+```bash
+mim train mmaction configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py --launcher pytorch --gpus 8
+```
+
+**To train with multiple GPUs by slurm:**
+
+```bash
+mim train mmaction configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py --launcher slurm \
+    --gpus 8 --gpus-per-node 8 --partition $PARTITION
+```
+
+### Testing commands
+
+**To test with a single GPU:**
+
+```bash
+mim test mmaction configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py --checkpoint $CHECKPOINT
+```
+
+**To test with multiple GPUs:**
+
+```bash
+mim test mmaction configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
+```
+
+**To test with multiple GPUs by slurm:**
+
+```bash
+mim test mmaction configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py --checkpoint $CHECKPOINT --launcher slurm \
+    --gpus 8 --gpus-per-node 8 --partition $PARTITION
+```
+
+> Replace $CHECKPOINT with the path of the trained model, e.g., work_dirs/basicTAD_slowonly_96x10_1200e_thumos14_rgb/latest.pth.
+
+## Results
+
+### THUMOS14
+
+| frame sampling strategy | resolution | gpus | backbone | pretrain | mAP@0.5 | avg. mAP | testing protocol |                               config                               | ckpt | log  |
+| :---------------------: | :--------: | :--: | :------: | :------: | :-----: | :------: | :--------------: | :----------------------------------------------------------------: | :--: | :--: |
+|         1x96x10         |  112x112   |  2   | SlowOnly | Kinetics |  50.4   |   47.9   | 1 clip x 1 crop  |  [config](./configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py)  | todo | todo |
+
+> Due to limited computing resources, we only trained the model in a simple setting (in terms of spatio-temporal resolution, testing augmentation, etc.). To reproduce the results in the paper, please refer to the [setting](https://github.com/MCG-NJU/BasicTAD/blob/main/configs/trainval/basictad/thumos14/basictad_slowonly_e700_thumos14_rgb_192win_anchor_based.py) used in the official repo.
+
+> In fact, the main idea of [BasicTAD](https://arxiv.org/abs/2205.02717) lies in its modular design rather than in any sophisticated architecture or modules.
+
+> Currently we only support the anchor-based BasicTAD model on THUMOS14. The anchor-free version is planned.
+
+> `avg. mAP` refers to the mAP averaged over IoU thresholds (0.3, 0.4, 0.5, 0.6, 0.7).
+
+## Citation
+
+```bibtex
+@article{yang2023basictad,
+  title={Basictad: an astounding rgb-only baseline for temporal action detection},
+  author={Yang, Min and Chen, Guo and Zheng, Yin-Dong and Lu, Tong and Wang, Limin},
+  journal={Computer Vision and Image Understanding},
+  volume={232},
+  pages={103692},
+  year={2023},
+  publisher={Elsevier}
+}
+```
+
+## Checklist
+
+Here is a checklist of this project's progress. You can ignore this part if you don't plan to contribute to MMAction2 projects.
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+  - [x] Finish the code
+
+  - [x] Basic docstrings & proper citation
+
+  - [ ] Converted checkpoint and results (Only for reproduction)
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+  - [x] Training results
+
+- [ ] Milestone 3: Good to be a part of our core package!
+ + - [ ] Unit tests + + + + - [ ] Code style + + + + - [ ] `metafile.yml` and `README.md` + + diff --git a/projects/basic_tad/configs/basicTAD_slowonly_192x5_1200e_thumos14_rgb.py b/projects/basic_tad/configs/basicTAD_slowonly_192x5_1200e_thumos14_rgb.py new file mode 100644 index 0000000000..9d0d97f4bb --- /dev/null +++ b/projects/basic_tad/configs/basicTAD_slowonly_192x5_1200e_thumos14_rgb.py @@ -0,0 +1,68 @@ +_base_ = ['./basicTAD_slowonly_96x10_1200e_thumos14_rgb.py'] +# model settings +model = dict( + neck=[ + dict(type='MaxPool3d', kernel_size=(2, 1, 1), stride=(2, 1, 1)), + dict(type='VDM', + in_channels=2048, + out_channels=512, + conv_cfg=dict(type='Conv3d'), + norm_cfg=dict(type='SyncBN'), + kernel_sizes=(3, 1, 1), + strides=(2, 1, 1), + paddings=(1, 0, 0), + stage_layers=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3, 4), + out_pooling=True), + dict(type='mmdet.FPN', + in_channels=[2048, 512, 512, 512, 512], + out_channels=256, + num_outs=5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='SyncBN'))], + bbox_head=dict(anchor_generator=dict(strides=[2, 4, 8, 16, 32]))) + +clip_len = 192 +frame_interval = 5 +img_shape = (112, 112) +img_shape_test = (128, 128) + +train_pipeline = [ + dict(type='Time2Frame'), + dict(type='TemporalRandomCrop', + clip_len=clip_len, + frame_interval=frame_interval, + iof_th=0.75), + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(128, -1), keep_ratio=True), # scale images' short-side to 128, keep aspect ratio + dict(type='SpatialRandomCrop', crop_size=img_shape), + dict(type='Flip', flip_ratio=0.5), + dict(type='PhotoMetricDistortion', + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18, + p=0.5), + dict(type='Rotate', + limit=(-45, 45), + border_mode='reflect_101', + p=0.5), + dict(type='Pad', size=(clip_len, *img_shape)), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackTadInputs', + meta_keys=('img_id', 'img_shape', 'pad_shape', 'scale_factor',)) +] + +val_pipeline = [ + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(128, -1), keep_ratio=True), + dict(type='SpatialCenterCrop', crop_size=img_shape_test), + dict(type='Pad', size=(clip_len, *img_shape_test)), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackTadInputs', + meta_keys=('img_id', 'img_shape', 'scale_factor', 'offset_sec')) +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(clip_len=clip_len, frame_interval=frame_interval, pipeline=val_pipeline)) +test_dataloader = val_dataloader diff --git a/projects/basic_tad/configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py b/projects/basic_tad/configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py new file mode 100644 index 0000000000..6ba0a1e663 --- /dev/null +++ b/projects/basic_tad/configs/basicTAD_slowonly_96x10_1200e_thumos14_rgb.py @@ -0,0 +1,237 @@ +# model settings +model = dict(type='mmdet.SingleStageDetector', + backbone=dict(type='SlowOnly'), + neck=[ + dict( + type='VDM', + in_channels=2048, + out_channels=512, + conv_cfg=dict(type='Conv3d'), + norm_cfg=dict(type='SyncBN'), + kernel_sizes=(3, 1, 1), + strides=(2, 1, 1), + paddings=(1, 0, 0), + stage_layers=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3, 4), + out_pooling=True), + dict(type='mmdet.FPN', + in_channels=[2048, 512, 512, 512, 512], + out_channels=256, + num_outs=5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='SyncBN'))], + bbox_head=dict( + type='RetinaHead1D', + num_classes=20, + in_channels=256, + 
conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='SyncBN'), + anchor_generator=dict( + type='mmdet.AnchorGenerator', + octave_base_scale=2, + scales_per_octave=5, + ratios=[1.0], + strides=[1, 2, 4, 8, 16]), + bbox_coder=dict( + type='mmdet.DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + reg_decoded_bbox=True, + loss_cls=dict(type='mmdet.FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), + loss_bbox=dict(type='DIoU1DLoss', loss_weight=1.0), + init_cfg=dict( + type='Normal', + layer='Conv1d', + std=0.01, + override=dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01))), + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + format_shape='NCTHW'), + train_cfg=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1, + ignore_wrt_candidates=True, + iou_calculator=dict(type='BboxOverlaps1D')), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict(nms_pre=300, score_thr=0.005)) # we perform NMS in Metric rather than in the model + +# dataset settings +data_root = 'data/thumos14' # Root path to data for training +data_prefix_train = 'rawframes/val' # path to data for training +data_prefix_val = 'rawframes/test' # path to data for validation and testing +ann_file_train = 'annotations/basicTAD/val.json' # Path to the annotation file for training +ann_file_val = 'annotations/basicTAD/test.json' # Path to the annotation file for validation +ann_file_test = ann_file_val + +clip_len = 96 +frame_interval = 10 +img_shape = (112, 112) +img_shape_test = (128, 128) +overlap_ratio = 0.25 + +train_pipeline = [ + dict(type='Time2Frame'), + dict(type='TemporalRandomCrop', + clip_len=clip_len, + frame_interval=frame_interval, + iof_th=0.75), + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(128, -1), keep_ratio=True), # scale images' short-side to 128, keep aspect ratio + dict(type='SpatialRandomCrop', crop_size=img_shape), + dict(type='Flip', flip_ratio=0.5), + dict(type='PhotoMetricDistortion', + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18, + p=0.5), + dict(type='Rotate', + limit=(-45, 45), + border_mode='reflect_101', + p=0.5), + dict(type='Pad', size=(clip_len, *img_shape)), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackTadInputs', + meta_keys=('img_id', 'img_shape', 'pad_shape', 'scale_factor',)) +] +val_pipeline = [ + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(128, -1), keep_ratio=True), + dict(type='SpatialCenterCrop', crop_size=img_shape_test), + dict(type='Pad', size=(clip_len, *img_shape_test)), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackTadInputs', + meta_keys=('img_id', 'img_shape', 'scale_factor', 'offset_sec')) +] +# test_pipeline = val_pipeline + +train_dataloader = dict( # Config of train dataloader + batch_size=2, # Batch size of each single GPU during training + num_workers=6, # Workers to pre-fetch data for each single GPU during training + persistent_workers=True, + # If `True`, the dataloader will not shut down the worker processes after an epoch end, which can accelerate training speed + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( # Config of train dataset + type='Thumos14Dataset', + filename_tmpl='img_{:05}.jpg', + ann_file=ann_file_train, # Path of annotation file + data_root=data_root, # Root path to 
data, including both frames and ann_file + data_prefix=dict(imgs=data_prefix_train), # Prefix of specific data, e.g., frames and ann_file + pipeline=train_pipeline)) +val_dataloader = dict( # Config of validation dataloader + batch_size=1, # Batch size of each single GPU during validation + num_workers=6, # Workers to pre-fetch data for each single GPU during validation + persistent_workers=True, # If `True`, the dataloader will not shut down the worker processes after an epoch end + sampler=dict(type='DefaultSampler', shuffle=False), # Not shuffle during validation and testing + # DefaultSampler which supports both distributed and non-distributed training. Refer to https://github.com/open-mmlab/mmengine/blob/main/mmengine/dataset/sampler.py) # Randomly shuffle the training data in each epoch + dataset=dict( # Config of validation dataset + type='Thumos14ValDataset', + clip_len=clip_len, frame_interval=frame_interval, overlap_ratio=overlap_ratio, + filename_tmpl='img_{:05}.jpg', + ann_file=ann_file_val, # Path of annotation file + data_root=data_root, + data_prefix=dict(imgs=data_prefix_val), # Prefix of specific data components + pipeline=val_pipeline, + test_mode=True)) +test_dataloader = val_dataloader + +# evaluation settings +val_evaluator = dict( # My customized evaluator for mean average precision + type='TADmAPMetric', + metric='mAP', + iou_thrs=[0.3, 0.4, 0.5, 0.6, 0.7], + nms_cfg=dict(type='nmw', iou_thr=0.6)) +test_evaluator = val_evaluator # Config of testing evaluator + +train_cfg = dict( # Config of training loop + type='EpochBasedTrainLoop', # Name of training loop + max_epochs=1200, # Total training epochs + val_begin=1, # The epoch that begins validating + val_interval=100) # Validation interval +val_cfg = dict( # Config of validation loop + type='ValLoop') # Name of validation loop +test_cfg = dict( # Config of testing loop + type='TestLoop') # Name of testing loop + +# learning policy +param_scheduler = [ # Parameter scheduler for updating optimizer parameters, support dict or list + # Linear learning rate warm-up scheduler + dict(type='LinearLR', + start_factor=0.1, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict(type='CosineRestartLR', # Decays the learning rate once the number of epoch reaches one of the milestones + periods=[100] * 12, + restart_weights=[1] * 12, + eta_min=1e-4, # The min_lr, note it's NOT the min_lr_ratio + by_epoch=True, + begin=40, + end=1240, + convert_to_iter_based=True)] # Convert to update by iteration. + +# optimizer +optim_wrapper = dict( # Config of optimizer wrapper + type='OptimWrapper', # Name of optimizer wrapper, switch to AmpOptimWrapper to enable mixed precision training + optimizer=dict( + # Config of optimizer. Support all kinds of optimizers in PyTorch. Refer to https://pytorch.org/docs/stable/optim.html#algorithms + type='SGD', # Name of optimizer + lr=0.01, # Learning rate + momentum=0.9, # Momentum factor + weight_decay=0.0001), # Weight decay + clip_grad=dict(max_norm=40, norm_type=2)) # Config of gradient clip +auto_scale_lr = dict(enable=False, base_batch_size=16) # The lr=0.01 is for batch_size=16. + +# runtime settings +# imports +custom_imports = dict(imports=['models'], allow_failed_imports=False) +default_scope = 'mmaction' # The default registry scope to find modules. Refer to https://mmengine.readthedocs.io/en/latest/tutorials/registry.html +default_hooks = dict( # Hooks to execute default actions like updating model parameters and saving checkpoints. 
+    runtime_info=dict(type='RuntimeInfoHook'),  # The hook to update runtime information into the message hub
+    timer=dict(type='IterTimerHook'),  # The hook to record the time spent during iteration
+    logger=dict(
+        type='LoggerHook',  # The logger used to record logs during training/validation/testing phase
+        interval=20,  # Interval to print the log
+        ignore_last=False,  # Do not ignore the log of the last iterations in each epoch
+        interval_exp_name=1000),  # Interval to print the experiment name
+    param_scheduler=dict(type='ParamSchedulerHook'),  # The hook to update some hyper-parameters in the optimizer
+    checkpoint=dict(
+        type='CheckpointHook',  # The hook to save checkpoints periodically
+        interval=100,  # The saving period
+        save_best='auto',  # Specified metric to measure the best checkpoint during evaluation
+        max_keep_ckpts=12),  # The maximum number of checkpoints to keep
+    sampler_seed=dict(type='DistSamplerSeedHook'),  # Set the seed of the data-loading sampler for distributed training
+    sync_buffers=dict(type='SyncBuffersHook'))  # Synchronize model buffers at the end of each epoch
+env_cfg = dict(  # Dict for setting environment
+    cudnn_benchmark=False,  # Whether to enable cudnn benchmark
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),  # Parameters to setup multiprocessing
+    dist_cfg=dict(backend='nccl'))  # Parameters to setup distributed environment, the port can also be set
+
+log_processor = dict(
+    type='LogProcessor',  # Log processor used to format log information
+    window_size=20,  # Default smooth interval
+    by_epoch=True)  # Whether to format logs with epoch type
+vis_backends = [  # List of visualization backends
+    dict(type='LocalVisBackend'),
+    dict(type='TensorboardVisBackend')]  # Local and TensorBoard visualization backends
+visualizer = dict(  # Config of visualizer
+    type='ActionVisualizer',  # Name of visualizer
+    vis_backends=vis_backends)
+# randomness = dict(seed=10, deterministic=True)
+# find_unused_parameters = True
+log_level = 'INFO'  # The level of logging
+load_from = None  # Load model checkpoint as a pre-trained model from a given path. This will not resume training.
+resume = False  # Whether to resume from the checkpoint defined in `load_from`. If `load_from` is None, it will resume the latest checkpoint in the `work_dir`.
diff --git a/projects/basic_tad/models/__init__.py b/projects/basic_tad/models/__init__.py
new file mode 100644
index 0000000000..4e943a69dc
--- /dev/null
+++ b/projects/basic_tad/models/__init__.py
@@ -0,0 +1,2 @@
+from .datasets import *
+from .models import *
diff --git a/projects/basic_tad/models/datasets/__init__.py b/projects/basic_tad/models/datasets/__init__.py
new file mode 100644
index 0000000000..195c0f08c0
--- /dev/null
+++ b/projects/basic_tad/models/datasets/__init__.py
@@ -0,0 +1,4 @@
+from .thumos14 import Thumos14Dataset
+from .thumos14_val import Thumos14ValDataset
+from .transforms import *
+from .tad_map_metric import TADmAPMetric
diff --git a/projects/basic_tad/models/datasets/tad_map_metric.py b/projects/basic_tad/models/datasets/tad_map_metric.py
new file mode 100644
index 0000000000..e988af5863
--- /dev/null
+++ b/projects/basic_tad/models/datasets/tad_map_metric.py
@@ -0,0 +1,226 @@
+import copy
+import warnings
+from collections import OrderedDict
+from typing import Sequence
+
+import numpy as np
+import torch
+from mmcv.ops import batched_nms
+from mmdet.evaluation.functional import eval_map, eval_recalls
+from mmdet.evaluation.metrics import VOCMetric
+from mmdet.structures.bbox import bbox_overlaps
+from mmengine.logging import MMLogger
+from mmengine.structures import InstanceData
+
+from mmaction.registry import METRICS
+from ..models.task_modules.segments_ops import batched_nmw
+
+
+@METRICS.register_module()
+class TADmAPMetric(VOCMetric):
+
+    def __init__(self,
+                 nms_cfg=dict(type='nms', iou_thr=0.4),
+                 max_per_video: int = False,
+                 score_thr=0.0,
+                 duration_thr=0.0,
+                 nms_in_overlap=False,
+                 eval_mode: str = 'area',
+                 **kwargs):
+        super().__init__(eval_mode=eval_mode, **kwargs)
+        self.nms_cfg = nms_cfg
+        self.max_per_video = max_per_video
+        self.score_thr = score_thr
+        self.duration_thr = duration_thr
+        self.nms_in_overlap = nms_in_overlap
+        if nms_cfg.get('type') in ['nms', 'soft_nms']:
+            self.nms = batched_nms
+        elif nms_cfg.get('type') == 'nmw':
+            warnings.warn('NMW is used, which is slower than NMS because it is implemented in plain Python and not optimized.')
+            self.nms = batched_nmw
+        else:
+            raise NotImplementedError(f'NMS type {nms_cfg.get("type")} is not implemented.')
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
+        """
+        for data_sample in data_samples:
+            data = copy.deepcopy(data_sample)
+            gts, dets = data['gt_instances'], data['pred_instances']
+            gts_ignore = data.get('ignored_instances', dict())
+            ann = dict(
+                video_name=data['img_id'],  # used later to group detections of the same video
+                labels=gts['labels'].cpu().numpy(),
+                bboxes=gts['bboxes'].cpu().numpy(),
+                bboxes_ignore=gts_ignore.get('bboxes', torch.empty((0, 4))).cpu().numpy(),
+                labels_ignore=gts_ignore.get('labels', torch.empty(0, )).cpu().numpy())
+
+            if self.nms_in_overlap:
+                ann['overlap'] = data['overlap']  # for NMS on overlapped regions in testing videos
+
+            # Convert the segment predictions from feature-unit to second-unit (add the window offset back first).
+ if 'offset_sec' in data: + dets['bboxes'] = dets['bboxes'] + data['offset_sec'] + + # Set y1, y2 of predictions the fixed value. + dets['bboxes'][:, 1] = 0.1 + dets['bboxes'][:, 3] = 0.9 + + # Filter out predictions with low scores + valid_inds = dets['scores'] > self.score_thr + + # Filter out predictions with short duration + valid_inds &= (dets['bboxes'][:, 2] - dets['bboxes'][:, 0]) > self.duration_thr + + dets['bboxes'] = dets['bboxes'][valid_inds].cpu() + dets['scores'] = dets['scores'][valid_inds].cpu() + dets['labels'] = dets['labels'][valid_inds].cpu() + + # Format predictions to InstanceData + dets = InstanceData(**dets) + + self.results.append((ann, dets)) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + gts, preds = zip(*results) + + # Following the TadTR, we cropped temporally OVERLAPPED sub-videos from the test video + # to handle test video of long duration while keep a fine temporal granularity. + # In this case, we need perform non-maximum suppression (NMS) to remove redundant detections. + # This NMS, however, is NOT necessary when window stride >= window size, i.e., non-overlapped sliding window. + logger.info(f'\n Concatenating the testing results ...') + gts, preds = self.merge_results_of_same_video(gts, preds) + preds = self.non_maximum_suppression(preds) + eval_results = OrderedDict() + if self.metric == 'mAP': + assert isinstance(self.iou_thrs, list) + dataset_name = self.dataset_meta['classes'] + + mean_aps = [] + for iou_thr in self.iou_thrs: + logger.info(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}') + mean_ap, _ = eval_map( + preds, + gts, + scale_ranges=self.scale_ranges, + iou_thr=iou_thr, + dataset=dataset_name, + logger=logger, + eval_mode=self.eval_mode, + use_legacy_coordinate=False) + mean_aps.append(mean_ap) + eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) + eval_results['mAP'] = sum(mean_aps) / len(mean_aps) + eval_results.move_to_end('mAP', last=False) + elif self.metric == 'recall': + # TODO: Currently not checked. + gt_bboxes = [ann['bboxes'] for ann in self.annotations] + recalls = eval_recalls( + gt_bboxes, + results, + self.proposal_nums, + self.iou_thrs, + logger=logger, + use_legacy_coordinate=False) + for i, num in enumerate(self.proposal_nums): + for j, iou_thr in enumerate(self.iou_thrs): + eval_results[f'recall@{num}@{iou_thr}'] = recalls[i, j] + if recalls.shape[1] > 1: + ar = recalls.mean(axis=1) + for i, num in enumerate(self.proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + return eval_results + + @staticmethod + def merge_results_of_same_video(gts, preds): + # Merge prediction results from the same videos because we use sliding windows to crop the testing videos + # Also known as the Cross-Window Fusion (CWF) + video_names = list(dict.fromkeys([gt['video_name'] for gt in gts])) + + merged_gts_dict = dict() + merged_preds_dict = dict() + for this_gt, this_pred in zip(gts, preds): + video_name = this_gt.pop('video_name') + # Computer the mask indicating that if a prediction is in the overlapped regions. 
+ overlap_regions = this_gt.pop('overlap', np.empty([0])) + if overlap_regions.size == 0: + this_pred.in_overlap = np.zeros(this_pred.bboxes.shape[0], dtype=bool) + else: + this_pred.in_overlap = bbox_overlaps(this_pred.bboxes, torch.from_numpy(overlap_regions)) > 0 + + merged_preds_dict.setdefault(video_name, []).append(this_pred) + merged_gts_dict.setdefault(video_name, this_gt) # the gt is video-wise thus no need concatenation + + # dict of list to list of dict + merged_gts = [] + merged_preds = [] + for video_name in video_names: + merged_gts.append(merged_gts_dict[video_name]) + # Concatenate detection in windows of the same video + merged_preds.append(InstanceData.cat(merged_preds_dict[video_name])) + return merged_gts, merged_preds + + def non_maximum_suppression(self, preds): + preds_nms = [] + for pred_v in preds: + if self.nms_cfg is not None: + if self.nms_in_overlap: + if pred_v.in_overlap.sum() > 1: + # Perform NMS over predictions in each overlapped region + pred_not_in_overlaps = pred_v[~pred_v.in_overlap.max(-1)[0]] + pred_in_overlaps = [] + for i in range(pred_v.in_overlap.shape[1]): + pred_in_overlap = pred_v[pred_v.in_overlap[:, i]] + if len(pred_in_overlap) == 0: + continue + bboxes, keep_idxs = self.nms(pred_in_overlap.bboxes, + pred_in_overlap.scores, + pred_in_overlap.labels, + nms_cfg=self.nms_cfg) + pred_in_overlap = pred_in_overlap[keep_idxs] + pred_in_overlap.bboxes = bboxes[:, :-1] + pred_in_overlap.scores = bboxes[:, -1] + pred_in_overlaps.append(pred_in_overlap) + pred_v = InstanceData.cat(pred_in_overlaps + [pred_not_in_overlaps]) + else: + bboxes, keep_idxs = self.nms(pred_v.bboxes, + pred_v.scores, + pred_v.labels, + nms_cfg=self.nms_cfg) + pred_v = pred_v[keep_idxs] + # Some NMS operations will change the value of scores and bboxes, we track it. 
+ pred_v.bboxes = bboxes[:, :-1] + pred_v.scores = bboxes[:, -1] + sort_idxs = pred_v.scores.argsort(descending=True) + pred_v = pred_v[sort_idxs] + # keep top-k predictions + if self.max_per_video: + pred_v = pred_v[:self.max_per_video] + + # Reformat predictions to meet the requirement of eval_map function: VideoList[ClassList[PredictionArray]] + dets = [] + for label in range(len(self.dataset_meta['classes'])): + index = np.where(pred_v.labels == label)[0] + pred_bbox_with_scores = np.hstack( + [pred_v[index].bboxes, pred_v[index].scores.reshape((-1, 1))]) + dets.append(pred_bbox_with_scores) + + preds_nms.append(dets) + return preds_nms diff --git a/projects/basic_tad/models/datasets/thumos14.py b/projects/basic_tad/models/datasets/thumos14.py new file mode 100644 index 0000000000..cbeeaff778 --- /dev/null +++ b/projects/basic_tad/models/datasets/thumos14.py @@ -0,0 +1,92 @@ +# adapted from basicTAD +import re +import warnings +from pathlib import Path + +import mmengine +import numpy as np +from mmaction.datasets import BaseActionDataset +from mmaction.registry import DATASETS + + +def make_regex_pattern(fixed_pattern): + # Use regular expression to extract number of digits + num_digits = re.search(r'\{:(\d+)\}', fixed_pattern).group(1) + # Build the pattern string using the extracted number of digits + pattern = fixed_pattern.replace('{:' + num_digits + '}', r'\d{' + num_digits + '}') + return pattern + + +@DATASETS.register_module() +class Thumos14Dataset(BaseActionDataset): + """Thumos14 dataset for temporal action detection.""" + + metainfo = dict(classes=('BaseballPitch', 'BasketballDunk', 'Billiards', 'CleanAndJerk', + 'CliffDiving', 'CricketBowling', 'CricketShot', 'Diving', + 'FrisbeeCatch', 'GolfSwing', 'HammerThrow', 'HighJump', + 'JavelinThrow', 'LongJump', 'PoleVault', 'Shotput', + 'SoccerPenalty', 'TennisSwing', 'ThrowDiscus', + 'VolleyballSpiking')) + + def __init__(self, filename_tmpl='img_{:05}.jpg', **kwargs): + self.filename_tmpl = filename_tmpl + super(Thumos14Dataset, self).__init__(**kwargs) + + def load_data_list(self): + data_list = [] + data = mmengine.load(self.ann_file) + for video_name, video_info in data['database'].items(): + # Meta information + frame_dir = Path(self.data_prefix['imgs']).joinpath(video_name) + if not frame_dir.exists(): + warnings.warn(f'{frame_dir} does not exist.') + continue + pattern = make_regex_pattern(self.filename_tmpl) + imgfiles = [img for img in frame_dir.iterdir() if re.fullmatch(pattern, img.name)] + num_imgs = len(imgfiles) + + data_info = dict(video_name=video_name, + frame_dir=str(frame_dir), + duration=float(video_info['duration']), + total_frames=num_imgs, + filename_tmpl=self.filename_tmpl, + fps=int(round(num_imgs / video_info['duration']))) + + # Segments information + segments = [] + labels = [] + ignore_flags = [] + for ann in video_info['annotations']: + label = ann['label'] + segment = ann['segment'] + + if not self.test_mode: + segment[0] = min(video_info['duration'], max(0, segment[0])) + segment[1] = min(video_info['duration'], max(0, segment[1])) + if segment[0] >= segment[1]: + continue + + if label in self.metainfo['classes']: + ignore_flags.append(0) + labels.append(self.metainfo['classes'].index(label)) + else: + ignore_flags.append(1) + labels.append(-1) + segments.append(segment) + + if not segments or np.all(ignore_flags): + warnings.warn(f'No valid segments found in video {video_name}. 
Excluded') + continue + + data_info.update(dict( + segments=np.array(segments, dtype=np.float32), + labels=np.array(labels, dtype=np.int64), + ignore_flags=np.array(ignore_flags, dtype=np.float32))) + + data_list.append(data_info) + + # standard_ann_file = dict() + # standard_ann_file['metainfo'] = dict(classes=self.CLASSES) + # standard_ann_file['data_list'] = data_list + # mmengine.dump(standard_ann_file, 'train.json') + return data_list diff --git a/projects/basic_tad/models/datasets/thumos14_val.py b/projects/basic_tad/models/datasets/thumos14_val.py new file mode 100644 index 0000000000..ffd58d83dc --- /dev/null +++ b/projects/basic_tad/models/datasets/thumos14_val.py @@ -0,0 +1,121 @@ +# adapted from basicTAD +import re +import warnings +from pathlib import Path + +import mmengine +import numpy as np + +from mmaction.datasets import BaseActionDataset +from mmaction.registry import DATASETS + + +def make_regex_pattern(fixed_pattern): + # Use regular expression to extract number of digits + num_digits = re.search(r'\{:(\d+)\}', fixed_pattern).group(1) + # Build the pattern string using the extracted number of digits + pattern = fixed_pattern.replace('{:' + num_digits + '}', r'\d{' + num_digits + '}') + return pattern + + +@DATASETS.register_module() +class Thumos14ValDataset(BaseActionDataset): + """Thumos14 dataset for temporal action detection.""" + + metainfo = dict(classes=('BaseballPitch', 'BasketballDunk', 'Billiards', 'CleanAndJerk', + 'CliffDiving', 'CricketBowling', 'CricketShot', 'Diving', + 'FrisbeeCatch', 'GolfSwing', 'HammerThrow', 'HighJump', + 'JavelinThrow', 'LongJump', 'PoleVault', 'Shotput', + 'SoccerPenalty', 'TennisSwing', 'ThrowDiscus', + 'VolleyballSpiking')) + + def __init__(self, clip_len=96, frame_interval=10, overlap_ratio=0.25, filename_tmpl='img_{:05}.jpg', **kwargs): + self.filename_tmpl = filename_tmpl + assert 0 <= overlap_ratio < 1 + self.clip_len = clip_len + self.frame_interval = frame_interval + self.overlap_ratio = overlap_ratio + + self.ori_clip_len = (self.clip_len - 1) * self.frame_interval + 1 + self.stride = int(self.ori_clip_len * (1 - self.overlap_ratio)) + + super(Thumos14ValDataset, self).__init__(**kwargs) + + def load_data_list(self): + data_list = [] + data = mmengine.load(self.ann_file) + for video_name, video_info in data['database'].items(): + # Segments information + segments = [] + labels = [] + ignore_flags = [] + for ann in video_info['annotations']: + label = ann['label'] + segment = ann['segment'] + + if not self.test_mode: + segment[0] = min(video_info['duration'], max(0, segment[0])) + segment[1] = min(video_info['duration'], max(0, segment[1])) + if segment[0] >= segment[1]: + continue + + if label in self.metainfo['classes']: + ignore_flags.append(0) + labels.append(self.metainfo['classes'].index(label)) + else: + ignore_flags.append(1) + labels.append(-1) + segments.append(segment) + if not segments: + segments = np.zeros((0, 2)) + labels = np.zeros((0,)) + ignore_flags = np.zeros((0,)) + else: + segments = np.array(segments) + labels = np.array(labels) + ignore_flags = np.array(ignore_flags) + + # Meta information + frame_dir = Path(self.data_prefix['imgs']).joinpath(video_name) + if not frame_dir.exists(): + warnings.warn(f'{frame_dir} does not exist.') + continue + pattern = make_regex_pattern(self.filename_tmpl) + imgfiles = [img for img in frame_dir.iterdir() if re.fullmatch(pattern, img.name)] + num_imgs = len(imgfiles) + + total_frames = num_imgs + offset = 0 + idx = 1 + while True: + if offset < total_frames - 1: + 
clip = offset + np.arange(self.clip_len) * self.frame_interval + clip = clip[clip < total_frames] + fps = round(num_imgs / video_info['duration']) + data_info = dict(video_name=f'{video_name}', + frame_dir=str(frame_dir), + duration=float(video_info['duration']), + total_frames=num_imgs, + filename_tmpl=self.filename_tmpl, + fps=fps, + frame_inds=clip, + offset_sec=offset / fps, + tsize=len(clip), + num_clips=1, + clip_len=self.clip_len, + frame_interval=self.frame_interval, + tscale_factor=fps / self.frame_interval, + segments=segments.astype(np.float32), + labels=labels.astype(np.int64), + ignore_flags=ignore_flags.astype(np.float32)) + data_list.append(data_info) + offset += self.stride + idx += 1 + else: + break + + # standard_ann_file = dict() + # standard_ann_file['metainfo'] = dict(classes=self.CLASSES) + # standard_ann_file['data_list'] = data_list + # mmengine.dump(standard_ann_file, 'train.json') + return data_list diff --git a/projects/basic_tad/models/datasets/transforms.py b/projects/basic_tad/models/datasets/transforms.py new file mode 100644 index 0000000000..64c49c0929 --- /dev/null +++ b/projects/basic_tad/models/datasets/transforms.py @@ -0,0 +1,554 @@ +# adapted from basicTAD +# https://github.com/open-mmlab/mmcv or +# https://github.com/open-mmlab/mmdetection + +from typing import Sequence + +import mmcv +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmdet.structures import DetDataSample +from mmengine.structures import InstanceData +from numpy import random + +from mmaction.registry import TRANSFORMS +from ..models.task_modules.segments_ops import segment_overlaps + + +@TRANSFORMS.register_module() +class Time2Frame(BaseTransform): + """Switch time point to frame index.""" + + def transform(self, results): + results['segments'] = results['segments'] * results['fps'] + + return results + + +@TRANSFORMS.register_module() +class TemporalRandomCrop(BaseTransform): + """Temporally crop. + + Args: + clip_len (int, optional): The cropped frame num. Default: 768. + iof_th(float, optional): The minimal iof threshold to crop. Default: 0 + """ + + def __init__(self, clip_len=96, frame_interval=10, iof_th=0.75): + self.clip_len = clip_len + self.frame_interval = frame_interval + self.iof_th = iof_th + + def get_valid_mask(self, segments, patch, iof_th): + gt_iofs = segment_overlaps(segments, patch, mode='iof')[:, 0] + patch_iofs = segment_overlaps(patch, segments, mode='iof')[0, :] + iofs = np.maximum(gt_iofs, patch_iofs) + mask = iofs > iof_th + + return mask + + def transform(self, results): + """Call function to random temporally crop video frame. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Temporally cropped results, 'frame_inds' is updated in + result dict. + """ + total_frames = results['total_frames'] + ori_clip_len = (self.clip_len - 1) * self.frame_interval + 1 + ori_clip_len = min(ori_clip_len, total_frames) + while True: + clip = np.arange(self.clip_len) * self.frame_interval + offset = np.random.randint(0, total_frames - ori_clip_len + 1) + clip = clip + offset + clip = clip[clip < total_frames] + start, end = clip[0], clip[-1] + + segments = results['segments'] + mask = self.get_valid_mask(segments, np.array([[start, end]], dtype=np.float32), self.iof_th) + + # If the cropped clip does NOT have IoF greater than the threshold with any (acknowledged) actions, then re-crop. 
+ if not np.logical_and(mask, np.logical_not(results['ignore_flags'])).any(): + continue + + segments = segments[mask] + segments = segments.clip(min=start, max=end) # TODO: Is this necessary? + segments -= start # transform the index of segments to be relative to the cropped segment + segments = segments / self.frame_interval # to be relative to the input clip + assert segments.max() < len(clip) + assert segments.min() >= 0 + + results['segments'] = segments + results['labels'] = results['labels'][mask] + results['ignore_flags'] = results['ignore_flags'][mask] + results['frame_inds'] = clip + assert max(results['frame_inds']) < total_frames, f"offset: {offset}\n" \ + f"start, end: [{start}, {end}]," \ + f"total frames: {total_frames}" + results['num_clips'] = 1 + results['clip_len'] = self.clip_len + results['tsize'] = len(clip) + + if 'img_idx_mapping' in results: + results['frame_inds'] = results['img_idx_mapping'][clip] + assert results['frame_inds'].max() < results['total_frames'] + assert results['frame_inds'].min() >= 0 + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_len={self.clip_len},' + repr_str += f'(frame_interval={self.frame_interval},' + repr_str += f'iof_th={self.iof_th})' + + return repr_str + + +@TRANSFORMS.register_module() +class SpatialRandomCrop(BaseTransform): + """Spatially random crop images. + Args: + crop_size (tuple): Expected size after cropping, (h, w). + Notes: + - If the image is smaller than the crop size, return the original image + """ + + def __init__(self, crop_size): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + + def transform(self, results): + """Call function to randomly crop images. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Randomly cropped results, 'imgs_shape' key in result dict + is updated according to crop size. + """ + img_h, img_w = results['img_shape'] + margin_h = max(img_h - self.crop_size[0], 0) + margin_w = max(img_w - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + # crop images + imgs = [img[crop_y1:crop_y2, crop_x1:crop_x2] for img in results['imgs']] + results['imgs'] = imgs + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@TRANSFORMS.register_module() +class PhotoMetricDistortion(BaseTransform): + """Apply photometric distortion to images sequentially, every + transformation is applied with a probability of 0.5. The position of random + contrast is in second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. 
+ """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18, + p=0.5): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + self.p = p + + def transform(self, results): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + imgs = np.array(results['imgs']).astype(np.float32) + + def _filter(img): + img[img < 0] = 0 + img[img > 255] = 255 + return img + + if random.uniform(0, 1) <= self.p: + + # random brightness + if random.randint(2): + delta = random.uniform(-self.brightness_delta, + self.brightness_delta) + imgs += delta + imgs = _filter(imgs) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + imgs *= alpha + imgs = _filter(imgs) + + # convert color from BGR to HSV + imgs = np.array([mmcv.image.bgr2hsv(img) for img in imgs]) + + # random saturation + if random.randint(2): + imgs[..., 1] *= random.uniform(self.saturation_lower, + self.saturation_upper) + + # random hue + # if random.randint(2): + if True: + imgs[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) + imgs[..., 0][imgs[..., 0] > 360] -= 360 + imgs[..., 0][imgs[..., 0] < 0] += 360 + + # convert color from HSV to BGR + imgs = np.array([mmcv.image.hsv2bgr(img) for img in imgs]) + imgs = _filter(imgs) + + # random contrast + if mode == 0: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + imgs *= alpha + imgs = _filter(imgs) + + # randomly swap channels + if random.randint(2): + imgs = imgs[..., random.permutation(3)] + + results['imgs'] = list(imgs) # change back to mmaction-style (list of) imgs + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(\nbrightness_delta={self.brightness_delta},\n' + repr_str += 'contrast_range=' + repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n' + repr_str += 'saturation_range=' + repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n' + repr_str += f'hue_delta={self.hue_delta})' + return repr_str + + +@TRANSFORMS.register_module() +class Rotate(BaseTransform): + """Spatially rotate images. + + Args: + limit (int, list or tuple): Angle range, (min_angle, max_angle). + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos". + Default: bilinear + border_mode (str): Border mode, accepted values are "constant", + "isolated", "reflect", "reflect_101", "replicate", "transparent", + "wrap". Default: constant + border_value (int): Border value. Default: 0 + """ + + def __init__(self, + limit, + interpolation='bilinear', + border_mode='constant', + border_value=0, + p=0.5): + if isinstance(limit, int): + limit = (-limit, limit) + self.limit = limit + self.interpolation = interpolation + self.border_mode = border_mode + self.border_value = border_value + self.p = p + + def transform(self, results): + """Call function to random rotate images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Spatially rotated results. 
+ """ + + if random.uniform(0, 1) <= self.p: + angle = random.uniform(*self.limit) + imgs = [ + mmcv.image.imrotate( + img, + angle=angle, + interpolation=self.interpolation, + border_mode=self.border_mode, + border_value=self.border_value) for img in results['imgs']] + + results['imgs'] = [np.ascontiguousarray(img) for img in imgs] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(limit={self.limit},' + repr_str += f'interpolation={self.interpolation},' + repr_str += f'border_mode={self.border_mode},' + repr_str += f'border_value={self.border_value},' + repr_str += f'p={self.p})' + + return repr_str + + +@TRANSFORMS.register_module() +class Pad(BaseTransform): + """Pad images. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value, 0 by default. + """ + + def __init__(self, size=None, size_divisor=None, pad_val=0): + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + # only one of size and size_divisor should be valid + assert size is not None or size_divisor is not None + assert size is None or size_divisor is None + + @staticmethod + def impad(img, shape, pad_val=0): + """Pad an image or images to a certain shape. + Args: + img (ndarray): Image to be padded. + shape (tuple[int]): Expected padding shape (h, w). + pad_val (Number | Sequence[Number]): Values to be filled in padding + areas. Default: 0. + Returns: + ndarray: The padded image. + """ + if not isinstance(pad_val, (int, float)): + assert len(pad_val) == img.shape[-1] + if len(shape) < len(img.shape): + shape = shape + (img.shape[-1],) + assert len(shape) == len(img.shape) + for s, img_s in zip(shape, img.shape): + assert s >= img_s, f"pad shape {s} should be greater than image shape {img_s}" + pad = np.empty(shape, dtype=img.dtype) + pad[...] = pad_val + pad[:img.shape[0], :img.shape[1], :img.shape[2], ...] = img + return pad + + @staticmethod + def impad_to_multiple(img, divisor, pad_val=0): + """Pad an image to ensure each edge to be multiple to some number. + Args: + img (ndarray): Image to be padded. + divisor (int): Padded image edges will be multiple to divisor. + pad_val (Number | Sequence[Number]): Same as :func:`impad`. + Returns: + ndarray: The padded image. + """ + pad_shape = tuple( + int(np.ceil(shape / divisor)) * divisor for shape in img.shape[:-1]) + return Pad.impad(img, pad_shape, pad_val) + + def transform(self, results): + """Call function to pad images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. 
+ """ + if self.size is not None: + padded_imgs = self.impad( + np.array(results['imgs']), shape=self.size, pad_val=self.pad_val) + elif self.size_divisor is not None: + padded_imgs = self.impad_to_multiple( + np.array(results['imgs']), self.size_divisor, pad_val=self.pad_val) + else: + raise AssertionError("Either 'size' or 'size_divisor' need to be set, but both None") + results['imgs'] = list(padded_imgs) # change back to mmaction-style (list of) imgs + results['pad_tsize'] = padded_imgs.shape[0] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, ' + repr_str += f'size_divisor={self.size_divisor}, ' + repr_str += f'pad_val={self.pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class PackTadInputs(BaseTransform): + """Pack the inputs data for the detection / semantic segmentation / + panoptic segmentation. + + The ``img_meta`` item is always populated. The contents of the + ``img_meta`` dictionary depends on ``meta_keys``. By default this includes: + + - ``img_id``: id of the image + + - ``img_path``: path to the image file + + - ``ori_shape``: original shape of the image as a tuple (h, w) + + - ``img_shape``: shape of the image input to the network as a tuple \ + (h, w). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - ``scale_factor``: a float indicating the preprocessing scale + + - ``flip``: a boolean indicating if image flip transform was used + + - ``flip_direction``: the flipping direction + + Args: + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')`` + """ + + def __init__(self, + meta_keys=('img_id', 'img_shape', 'scale_factor')): + self.meta_keys = meta_keys + + @staticmethod + def mmdet_mapping(results: dict) -> dict: + # Modify the meta keys/values to be consistent with mmdet + results['img'] = results['imgs'] + results['img_shape'] = (1, results.pop('tsize')) + results['pad_shape'] = (1, results.pop('pad_tsize')) + if 'tscale_factor' in results: + results['scale_factor'] = (results.pop('tscale_factor'), 1) # (w, h) + results['img_id'] = results.pop('video_name') + + gt_bboxes = np.insert(results['segments'], 2, 0.9, axis=-1) + gt_bboxes = np.insert(gt_bboxes, 1, 0.1, axis=-1) + results['bboxes'] = gt_bboxes + results['labels'] = results.pop('labels') + + return results + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`DetDataSample`): The annotation info of the + sample. + """ + results = self.mmdet_mapping(results) + packed_results = dict() + + img = results['img'] + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img)) + else: + img = to_tensor(img).contiguous() + + packed_results['inputs'] = img + + data_sample = DetDataSample(gt_instances=InstanceData(bboxes=to_tensor(results['bboxes']), + labels=to_tensor(results['labels']))) + img_meta = {} + for key in self.meta_keys: + assert key in results, f'`{key}` is not found in `results`, ' \ + f'the valid keys are {list(results)}.' 
+ img_meta[key] = results[key] + + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class SpatialCenterCrop(BaseTransform): + """Spatially center crop images. + + Args: + crop_size (tuple): Expected size after cropping, (h, w). + + Notes: + - If the image is smaller than the crop size, return the original image + """ + + def __init__(self, crop_size): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + + def transform(self, results): + """Call function to center crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'imgs_shape' key in result dict + is updated according to crop size. + """ + + imgs = np.array(results['imgs']) + margin_h = max(imgs.shape[1] - self.crop_size[0], 0) + margin_w = max(imgs.shape[2] - self.crop_size[1], 0) + offset_h = int(margin_h / 2) + offset_w = int(margin_w / 2) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + # crop images + imgs = imgs[:, crop_y1:crop_y2, crop_x1:crop_x2, ...] + results['imgs'] = list(imgs) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' diff --git a/projects/basic_tad/models/models/__init__.py b/projects/basic_tad/models/models/__init__.py new file mode 100644 index 0000000000..7d623a9eda --- /dev/null +++ b/projects/basic_tad/models/models/__init__.py @@ -0,0 +1,5 @@ +from .backbones import * +from .heads import * +from .necks import * +from .task_modules import * +from .losses import * diff --git a/projects/basic_tad/models/models/backbones/__init__.py b/projects/basic_tad/models/models/backbones/__init__.py new file mode 100644 index 0000000000..6548515cac --- /dev/null +++ b/projects/basic_tad/models/models/backbones/__init__.py @@ -0,0 +1 @@ +from .slowonly import SlowOnly diff --git a/projects/basic_tad/models/models/backbones/decorator.py b/projects/basic_tad/models/models/backbones/decorator.py new file mode 100644 index 0000000000..b68559abe8 --- /dev/null +++ b/projects/basic_tad/models/models/backbones/decorator.py @@ -0,0 +1,8 @@ +def crops_to_batch(forward_methods): + + def wrapper(self, inputs, *args, **kwargs): + num_crops = inputs.shape[1] + inputs = inputs.view(-1, *inputs.shape[2:]) + return forward_methods(self, inputs, *args, **kwargs) + + return wrapper diff --git a/projects/basic_tad/models/models/backbones/slowonly.py b/projects/basic_tad/models/models/backbones/slowonly.py new file mode 100644 index 0000000000..1e57a71490 --- /dev/null +++ b/projects/basic_tad/models/models/backbones/slowonly.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn + +from mmaction.registry import MODELS +from .decorator import crops_to_batch + + +@MODELS.register_module() +class SlowOnly(nn.Module): + + def __init__(self, + out_indices=(4,), + freeze_bn=True, + freeze_bn_affine=True + ): + super(SlowOnly, self).__init__() + model = torch.hub.load("facebookresearch/pytorchvideo", model='slow_r50', pretrained=True) + self.blocks = model.blocks[:-1] # exclude the last HEAD block + self.out_indices = out_indices + self._freeze_bn = freeze_bn + self._freeze_bn_affine = freeze_bn_affine + + @crops_to_batch + def forward(self, x): + outs = [] + for i, block in enumerate(self.blocks): + x 
= block(x) + if i in self.out_indices: + outs.append(x) + if len(outs) == 1: + return outs[0] + return outs + + def train(self, mode=True): + super(SlowOnly, self).train(mode) + if self._freeze_bn and mode: + for name, m in self.named_modules(): + if isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)): + m.eval() + if self._freeze_bn_affine: + m.weight.register_hook(lambda grad: torch.zeros_like(grad)) + m.bias.register_hook(lambda grad: torch.zeros_like(grad)) diff --git a/projects/basic_tad/models/models/heads/__init__.py b/projects/basic_tad/models/models/heads/__init__.py new file mode 100644 index 0000000000..8eda10c236 --- /dev/null +++ b/projects/basic_tad/models/models/heads/__init__.py @@ -0,0 +1 @@ +from .retina_head_1d import RetinaHead1D \ No newline at end of file diff --git a/projects/basic_tad/models/models/heads/retina_head_1d.py b/projects/basic_tad/models/models/heads/retina_head_1d.py new file mode 100644 index 0000000000..0d411f9951 --- /dev/null +++ b/projects/basic_tad/models/models/heads/retina_head_1d.py @@ -0,0 +1,438 @@ +# Adapted from https://github.com/MCG-NJU/BasicTAD/ +# https://github.com/open-mmlab/mmcv or +# https://github.com/open-mmlab/mmdetection + +import copy +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmaction.registry import MODELS +from mmdet.models.dense_heads import RetinaHead +import warnings + + +@MODELS.register_module() +class RetinaHead1D(RetinaHead): + r"""Modified RetinaHead to support 1D + """ + + def _init_layers(self): + super()._init_layers() + self.retina_cls = nn.Conv1d( # --------------- + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + reg_dim = self.bbox_coder.encode_size // 2 + self.retina_reg = nn.Conv1d( # --------------- + self.feat_channels, self.num_base_priors * reg_dim, 3, padding=1) + + def forward_single(self, x): + cls_score, bbox_pred = super().forward_single(x) + # add pseudo H dimension + cls_score, bbox_pred = cls_score.unsqueeze(-2), bbox_pred.unsqueeze(-2) + # bbox_pred = [N, 2], where 2 is the x, w. Now adding pseudo y, h + bbox_pred = bbox_pred.unflatten(1, (self.num_base_priors, -1)) + y, h = torch.split(torch.zeros_like(bbox_pred), 1, dim=2) + bbox_pred = torch.cat((bbox_pred[:, :, :1, :, :], y, bbox_pred[:, :, 1:, :, :], h), dim=2) + bbox_pred = bbox_pred.flatten(start_dim=1, end_dim=2) + return cls_score, bbox_pred + + def predict_by_feat(self, *args, **kwargs): + # As we predict sliding windows of untrimmed videos, we do not perform NMS inside each window but + # leave the NMS performed globally on the whole video. 
+ if kwargs.get('with_nms', False): + warnings.warn("with_nms is True, which is unexpected as we should perform NMS in Metric rather than in model") + else: + kwargs['with_nms'] = False + return super().predict_by_feat(*args, **kwargs) + +# def get_anchors(self, +# featmap_sizes: List[tuple], +# batch_img_metas: List[dict], +# device: Union[torch.device, str] = 'cuda') \ +# -> Tuple[List[List[Tensor]], List[List[Tensor]]]: +# num_imgs = len(batch_img_metas) +# +# # since feature map sizes of all images are the same, we only compute +# # anchors for one time +# multi_level_anchors = self.prior_generator.grid_priors( +# featmap_sizes, device=device) +# anchor_list = [multi_level_anchors for _ in range(num_imgs)] +# +# # for each image, we compute valid flags of multi level anchors +# valid_flag_list = [] +# for img_id, img_meta in enumerate(batch_img_metas): +# multi_level_flags = self.prior_generator.valid_flags( +# featmap_sizes, img_meta['pad_tsize'], device) # --------------------------- +# valid_flag_list.append(multi_level_flags) +# +# return anchor_list, valid_flag_list +# +# def _get_targets_single(self, +# flat_anchors: Union[Tensor, BaseBoxes], +# valid_flags: Tensor, +# gt_instances: InstanceData, +# img_meta: dict, +# gt_instances_ignore: Optional[InstanceData] = None, +# unmap_outputs: bool = True) -> tuple: +# inside_flags = anchor_inside_flags(flat_anchors, valid_flags, # --------------------- +# img_meta['tsize'], # --------------------- +# self.train_cfg['allowed_border']) +# if not inside_flags.any(): +# raise ValueError( +# 'There is no valid anchor inside the image boundary. Please ' +# 'check the image size and anchor sizes, or set ' +# '``allowed_border`` to -1 to skip the condition.') +# # assign gt and sample anchors +# anchors = flat_anchors[inside_flags] +# +# pred_instances = InstanceData(priors=anchors) +# assign_result = self.assigner.assign(pred_instances, gt_instances, +# gt_instances_ignore) +# # No sampling is required except for RPN and +# # Guided Anchoring algorithms +# sampling_result = self.sampler.sample(assign_result, pred_instances, +# gt_instances) +# +# num_valid_anchors = anchors.shape[0] +# target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox \ +# else self.bbox_coder.encode_size +# bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim) +# bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim) +# +# # TODO: Considering saving memory, is it necessary to be long? +# labels = anchors.new_full((num_valid_anchors,), +# self.num_classes, +# dtype=torch.long) +# label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) +# +# pos_inds = sampling_result.pos_inds +# neg_inds = sampling_result.neg_inds +# # `bbox_coder.encode` accepts tensor or box type inputs and generates +# # tensor targets. If regressing decoded boxes, the code will convert +# # box type `pos_bbox_targets` to tensor. 
+# if len(pos_inds) > 0: +# if not self.reg_decoded_bbox: +# pos_bbox_targets = self.bbox_coder.encode( +# sampling_result.pos_priors, sampling_result.pos_gt_bboxes) +# else: +# pos_bbox_targets = sampling_result.pos_gt_bboxes +# pos_bbox_targets = get_box_tensor(pos_bbox_targets) +# bbox_targets[pos_inds, :] = pos_bbox_targets +# bbox_weights[pos_inds, :] = 1.0 +# +# labels[pos_inds] = sampling_result.pos_gt_labels +# if self.train_cfg['pos_weight'] <= 0: +# label_weights[pos_inds] = 1.0 +# else: +# label_weights[pos_inds] = self.train_cfg['pos_weight'] +# if len(neg_inds) > 0: +# label_weights[neg_inds] = 1.0 +# +# # map up to original set of anchors +# if unmap_outputs: +# num_total_anchors = flat_anchors.size(0) +# labels = unmap( +# labels, num_total_anchors, inside_flags, +# fill=self.num_classes) # fill bg label +# label_weights = unmap(label_weights, num_total_anchors, +# inside_flags) +# bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) +# bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) +# +# return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, +# neg_inds, sampling_result) +# +# def loss_by_feat( +# self, +# cls_scores: List[Tensor], +# bbox_preds: List[Tensor], +# batch_gt_instances: InstanceList, +# batch_img_metas: List[dict], +# batch_gt_instances_ignore: OptInstanceList = None) -> dict: +# featmap_sizes = [featmap.size()[-1] for featmap in cls_scores] # ----------------- +# assert len(featmap_sizes) == self.prior_generator.num_levels +# +# device = cls_scores[0].device +# +# anchor_list, valid_flag_list = self.get_anchors( +# featmap_sizes, batch_img_metas, device=device) +# cls_reg_targets = self.get_targets( +# anchor_list, +# valid_flag_list, +# batch_gt_instances, +# batch_img_metas, +# batch_gt_instances_ignore=batch_gt_instances_ignore) +# (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, +# avg_factor) = cls_reg_targets +# +# # anchor number of multi levels +# num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] +# # concat all level anchors and flags to a single tensor +# concat_anchor_list = [] +# for i in range(len(anchor_list)): +# concat_anchor_list.append(cat_boxes(anchor_list[i])) +# all_anchor_list = images_to_levels(concat_anchor_list, +# num_level_anchors) +# +# losses_cls, losses_bbox = multi_apply( +# self.loss_by_feat_single, +# cls_scores, +# bbox_preds, +# all_anchor_list, +# labels_list, +# label_weights_list, +# bbox_targets_list, +# bbox_weights_list, +# avg_factor=avg_factor) +# return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) +# +# def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, +# anchors: Tensor, labels: Tensor, +# label_weights: Tensor, bbox_targets: Tensor, +# bbox_weights: Tensor, avg_factor: int) -> tuple: +# # classification loss +# labels = labels.reshape(-1) +# label_weights = label_weights.reshape(-1) +# cls_score = cls_score.permute(0, 2, 1).reshape(-1, self.cls_out_channels) # --------------- +# loss_cls = self.loss_cls( +# cls_score, labels, label_weights, avg_factor=avg_factor) +# # regression loss +# target_dim = bbox_targets.size(-1) +# bbox_targets = bbox_targets.reshape(-1, target_dim) +# bbox_weights = bbox_weights.reshape(-1, target_dim) +# bbox_pred = bbox_pred.permute(0, 2, 1).reshape(-1, self.bbox_coder.encode_size) # --------------- +# if self.reg_decoded_bbox: +# # When the regression loss (e.g. 
`IouLoss`, `GIouLoss`) +# # is applied directly on the decoded bounding boxes, it +# # decodes the already encoded coordinates to absolute format. +# anchors = anchors.reshape(-1, anchors.size(-1)) +# d_bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) +# d_bbox_pred = get_box_tensor(d_bbox_pred) +# loss_bbox = self.loss_bbox( +# d_bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor) +# return loss_cls, loss_bbox +# +# def predict_by_feat(self, +# cls_scores: List[Tensor], +# bbox_preds: List[Tensor], +# score_factors: Optional[List[Tensor]] = None, +# batch_img_metas: Optional[List[dict]] = None, +# cfg: Optional[ConfigDict] = None, +# rescale: bool = True, # ------------------------------ +# with_nms: bool = False) -> InstanceList: # ----------------------- +# assert len(cls_scores) == len(bbox_preds) +# +# if score_factors is None: +# # e.g. Retina, FreeAnchor, Foveabox, etc. +# with_score_factors = False +# else: +# # e.g. FCOS, PAA, ATSS, AutoAssign, etc. +# with_score_factors = True +# assert len(cls_scores) == len(score_factors) +# +# num_levels = len(cls_scores) +# +# featmap_sizes = [cls_scores[i].shape[-1] for i in range(num_levels)] # ------------------ +# mlvl_priors = self.prior_generator.grid_priors( +# featmap_sizes, +# dtype=cls_scores[0].dtype, +# device=cls_scores[0].device) +# +# result_list = [] +# +# for img_id in range(len(batch_img_metas)): +# img_meta = batch_img_metas[img_id] +# cls_score_list = select_single_mlvl( +# cls_scores, img_id, detach=True) +# bbox_pred_list = select_single_mlvl( +# bbox_preds, img_id, detach=True) +# if with_score_factors: +# score_factor_list = select_single_mlvl( +# score_factors, img_id, detach=True) +# else: +# score_factor_list = [None for _ in range(num_levels)] +# +# results = self._predict_by_feat_single( +# cls_score_list=cls_score_list, +# bbox_pred_list=bbox_pred_list, +# score_factor_list=score_factor_list, +# mlvl_priors=mlvl_priors, +# img_meta=img_meta, +# cfg=cfg, +# rescale=rescale, +# with_nms=with_nms) +# result_list.append(results) +# return result_list +# +# def _predict_by_feat_single(self, +# cls_score_list: List[Tensor], +# bbox_pred_list: List[Tensor], +# score_factor_list: List[Tensor], +# mlvl_priors: List[Tensor], +# img_meta: dict, +# cfg: ConfigDict, +# rescale: bool = True, # ------------------ +# with_nms: bool = False) -> InstanceData: # -------------------------- +# if score_factor_list[0] is None: +# # e.g. Retina, FreeAnchor, etc. +# with_score_factors = False +# else: +# # e.g. FCOS, PAA, ATSS, etc. 
+# with_score_factors = True +# +# cfg = self.test_cfg if cfg is None else cfg +# cfg = copy.deepcopy(cfg) +# tsize = img_meta['tsize'] # --------------- +# nms_pre = cfg.get('nms_pre', -1) +# +# mlvl_bbox_preds = [] +# mlvl_valid_priors = [] +# mlvl_scores = [] +# mlvl_labels = [] +# if with_score_factors: +# mlvl_score_factors = [] +# else: +# mlvl_score_factors = None +# for level_idx, (cls_score, bbox_pred, score_factor, priors) in \ +# enumerate(zip(cls_score_list, bbox_pred_list, +# score_factor_list, mlvl_priors)): +# +# assert cls_score.size()[-1] == bbox_pred.size()[-1] # -------------- +# +# dim = self.bbox_coder.encode_size +# bbox_pred = bbox_pred.permute(1, 0).reshape(-1, dim) # -------------- +# if with_score_factors: +# score_factor = score_factor.permute(1, 0).reshape(-1).sigmoid() # -------------- +# cls_score = cls_score.permute(1, 0).reshape(-1, self.cls_out_channels) # ---------------- +# if self.use_sigmoid_cls: +# scores = cls_score.sigmoid() +# else: +# # remind that we set FG labels to [0, num_class-1] +# # since mmdet v2.0 +# # BG cat_id: num_class +# scores = cls_score.softmax(-1)[:, :-1] +# +# # After https://github.com/open-mmlab/mmdetection/pull/6268/, +# # this operation keeps fewer bboxes under the same `nms_pre`. +# # There is no difference in performance for most models. If you +# # find a slight drop in performance, you can set a larger +# # `nms_pre` than before. +# score_thr = cfg.get('score_thr', 0) +# +# results = filter_scores_and_topk( +# scores, score_thr, nms_pre, +# dict(bbox_pred=bbox_pred, priors=priors)) +# scores, labels, keep_idxs, filtered_results = results +# +# bbox_pred = filtered_results['bbox_pred'] +# priors = filtered_results['priors'] +# +# if with_score_factors: +# score_factor = score_factor[keep_idxs] +# +# mlvl_bbox_preds.append(bbox_pred) +# mlvl_valid_priors.append(priors) +# mlvl_scores.append(scores) +# mlvl_labels.append(labels) +# +# if with_score_factors: +# mlvl_score_factors.append(score_factor) +# +# bbox_pred = torch.cat(mlvl_bbox_preds) +# priors = cat_boxes(mlvl_valid_priors) +# bboxes = self.bbox_coder.decode(priors, bbox_pred, max_t=tsize) +# +# results = InstanceData() +# results.bboxes = bboxes +# results.scores = torch.cat(mlvl_scores) +# results.labels = torch.cat(mlvl_labels) +# if with_score_factors: +# results.score_factors = torch.cat(mlvl_score_factors) +# results = self._bbox_post_process( +# results=results, +# cfg=cfg, +# rescale=rescale, +# with_nms=with_nms, +# img_meta=img_meta) +# return results +# +# def _bbox_post_process(self, +# results: InstanceData, +# cfg: ConfigDict, +# rescale: bool = True, # -------------------------- +# with_nms: bool = False, # ---------------------------- +# img_meta: Optional[dict] = None) -> InstanceData: +# if rescale: +# assert img_meta.get('tscale_factor') is not None +# tscale_factor = 1 / img_meta['tscale_factor'] # ------------ +# results.bboxes = scale_boxes(results.bboxes, tscale_factor) +# # Convert the bboxes co-ordinate from the input video segment to the original video +# results.bboxes += img_meta.get('tshift', 0) # --------------- +# +# if hasattr(results, 'score_factors'): +# # TODO: Add sqrt operation in order to be consistent with +# # the paper. 
+# score_factors = results.pop('score_factors') +# results.scores = results.scores * score_factors +# +# # filter small size bboxes +# if cfg.get('min_bbox_size', 0) >= 0: +# l = get_segment_len(results.bboxes) # --------------- +# valid_mask = l > cfg.get('min_bbox_size', 0) # --------------- +# if not valid_mask.all(): +# results = results[valid_mask] +# +# # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg +# if with_nms and results.bboxes.numel() > 0: +# bboxes = get_box_tensor(results.bboxes) +# det_bboxes, keep_idxs = batched_nms1d(bboxes, results.scores, +# results.labels, +# cfg.get('nms', dict(type='nms', iou_thr=0.5))) # ------------------ +# results = results[keep_idxs] +# # some nms would reweight the score, such as softnms +# results.scores = det_bboxes[:, -1] +# results = results[:cfg.get('max_per_video', 100)] # ------------------ +# +# return results +# +# def loss_and_predict( +# self, +# x: Tuple[Tensor], +# batch_data_samples: SampleList, +# with_nms: bool = False, # ------------------ +# rescale: bool = True) -> Tuple[dict, InstanceList]: # ------------------ +# +# outputs = unpack_gt_instances(batch_data_samples) +# (batch_gt_instances, batch_gt_instances_ignore, +# batch_img_metas) = outputs +# +# outs = self(x) +# +# loss_inputs = outs + (batch_gt_instances, batch_img_metas, +# batch_gt_instances_ignore) +# losses = self.loss_by_feat(*loss_inputs) +# +# predictions = self.predict_by_feat( +# *outs, with_nms=with_nms, batch_img_metas=batch_img_metas, rescale=rescale) +# return losses, predictions +# +# +# def get_segment_len(segments: Union[Tensor, BaseBoxes]) -> Tuple[Tensor, Tensor]: +# """Get the width and height of boxes with type of tensor or box type. +# +# Args: +# segments (Tensor or :obj:`BaseBoxes`): boxes with type of tensor +# or box type. +# +# Returns: +# Tuple[Tensor, Tensor]: the width and height of boxes. 
+# """ +# if isinstance(segments, BaseBoxes): +# l = segments.length +# else: +# # Tensor boxes will be treated as horizontal boxes by defaults +# l = segments[:, 1] - segments[:, 0] +# return l diff --git a/projects/basic_tad/models/models/losses/__init__.py b/projects/basic_tad/models/models/losses/__init__.py new file mode 100644 index 0000000000..0e727df9b5 --- /dev/null +++ b/projects/basic_tad/models/models/losses/__init__.py @@ -0,0 +1 @@ +from .diou_1d_loss import DIoU1DLoss diff --git a/projects/basic_tad/models/models/losses/diou_1d_loss.py b/projects/basic_tad/models/models/losses/diou_1d_loss.py new file mode 100644 index 0000000000..6e4391669f --- /dev/null +++ b/projects/basic_tad/models/models/losses/diou_1d_loss.py @@ -0,0 +1,22 @@ +# adapted from basicTAD +from mmdet.models.losses import DIoULoss +from torch import Tensor + +from mmaction.registry import MODELS + + +def zero_out_loss_coordinates_decorator(forward_method): + def wrapper(self, pred: Tensor, target: Tensor, *args, **kwargs): + pred = pred.clone() + pred[:, 1] = pred[:, 1] * 0 + target[:, 1] + pred[:, 3] = pred[:, 3] * 0 + target[:, 3] + return forward_method(self, pred, target, *args, **kwargs) + + return wrapper + + +@MODELS.register_module() +class DIoU1DLoss(DIoULoss): + @zero_out_loss_coordinates_decorator + def forward(self, pred: Tensor, target: Tensor, *args, **kwargs) -> Tensor: + return super().forward(pred, target, *args, **kwargs) diff --git a/projects/basic_tad/models/models/necks/__init__.py b/projects/basic_tad/models/models/necks/__init__.py new file mode 100644 index 0000000000..5c653924f3 --- /dev/null +++ b/projects/basic_tad/models/models/necks/__init__.py @@ -0,0 +1 @@ +from .vdm import VDM diff --git a/projects/basic_tad/models/models/necks/vdm.py b/projects/basic_tad/models/models/necks/vdm.py new file mode 100644 index 0000000000..9f9a094b8e --- /dev/null +++ b/projects/basic_tad/models/models/necks/vdm.py @@ -0,0 +1,135 @@ +import torch.nn as nn +from mmaction.registry import MODELS +from mmcv.cnn import ConvModule +from mmengine.model import kaiming_init, constant_init +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.conv import _ConvNd +from torch.nn import MaxPool3d + +MODELS.register_module(name='MaxPool3d', module=MaxPool3d) + +# x: b,c,t,h,w + + +def n_tuple(x, num): + return [x for i in range(num)] + + +@MODELS.register_module() +class VDM(nn.Module): + """Temporal Down-Sampling Module.""" + + def __init__(self, + in_channels=2048, + stage_layers=(1, 1, 1, 1), + kernel_sizes=(3, 1, 1), + strides=(2, 1, 1), + paddings=(1, 0, 0), + dilations=(1, 1, 1), + out_channels=512, + conv_cfg=dict(type='Conv3d'), + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='ReLU'), + out_indices=(0, 1, 2, 3, 4), + out_pooling=True, + ): + super(VDM, self).__init__() + + self.in_channels = in_channels + self.num_stages = len(stage_layers) + self.stage_layers = stage_layers + self.kernel_sizes = n_tuple(kernel_sizes, self.num_stages) + self.strides = n_tuple(strides, self.num_stages) + self.paddings = n_tuple(paddings, self.num_stages) + self.dilations = n_tuple(dilations, self.num_stages) + self.out_channels = n_tuple(out_channels, self.num_stages) + self.out_pooling = out_pooling + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.out_indices = out_indices + + assert (len(self.stage_layers) == len(self.kernel_sizes) == len( + self.strides) == len(self.paddings) == len(self.dilations) == len( + self.out_channels)) + + self.td_layers = [] + 
for i in range(self.num_stages): + td_layer = self.make_td_layer(self.stage_layers[i], in_channels, + self.out_channels[i], + self.kernel_sizes[i], + self.strides[i], self.paddings[i], + self.dilations[i], self.conv_cfg, + self.norm_cfg, self.act_cfg) + in_channels = self.out_channels[i] + layer_name = f'layer{i + 1}' + self.add_module(layer_name, td_layer) + self.td_layers.append(layer_name) + + self.spatial_pooling = nn.AdaptiveAvgPool3d((None, 1, 1)) + + def sp(self, x): + return self.spatial_pooling(x).squeeze(-1).squeeze(-1) + + @staticmethod + def make_td_layer(num_layer, in_channels, out_channels, kernel_size, + stride, padding, dilation, conv_cfg, norm_cfg, act_cfg): + layers = [] + layers.append( + ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(1, num_layer): + layers.append( + ConvModule( + out_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + return nn.Sequential(*layers) + + def init_weights(self): + """Initiate the parameters.""" + for m in self.modules(): + if isinstance(m, _ConvNd): + kaiming_init(m) + elif isinstance(m, _BatchNorm): + constant_init(m, 1) + + def train(self, mode=True): + """Set the optimization status when training.""" + super().train(mode) + + if mode: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + + def forward(self, x): + outs = [] + if 0 in self.out_indices: + outs.append(x) + + for i, layer_name in enumerate(self.td_layers): + layer = getattr(self, layer_name) + x = layer(x) + if (i + 1) in self.out_indices: + outs.append(x) + if len(outs) == 1: + return outs[0] + + if self.out_pooling: + for i in range(len(outs)): + outs[i] = self.sp(outs[i]) + + return tuple(outs) diff --git a/projects/basic_tad/models/models/task_modules/__init__.py b/projects/basic_tad/models/models/task_modules/__init__.py new file mode 100644 index 0000000000..0597af01d4 --- /dev/null +++ b/projects/basic_tad/models/models/task_modules/__init__.py @@ -0,0 +1 @@ +from .iou1d_calculator import BboxOverlaps1D diff --git a/projects/basic_tad/models/models/task_modules/iou1d_calculator.py b/projects/basic_tad/models/models/task_modules/iou1d_calculator.py new file mode 100644 index 0000000000..1889b9de4b --- /dev/null +++ b/projects/basic_tad/models/models/task_modules/iou1d_calculator.py @@ -0,0 +1,16 @@ +from mmdet.models.task_modules.assigners.iou2d_calculator import BboxOverlaps2D +from mmaction.registry import TASK_UTILS +from mmdet.structures.bbox import get_box_tensor + + +@TASK_UTILS.register_module() +class BboxOverlaps1D(BboxOverlaps2D): + """IoU Calculator that ignore the y1 and y2.""" + + def __call__(self, bboxes1, bboxes2, *args, **kwargs): + bboxes1, bboxes2 = get_box_tensor(bboxes1), get_box_tensor(bboxes2) + bboxes1[:, 1] = 0.1 + bboxes2[:, 1] = 0.1 + bboxes1[:, 3] = 0.9 + bboxes2[:, 3] = 0.9 + return super().__call__(bboxes1, bboxes2, *args, **kwargs) diff --git a/projects/basic_tad/models/models/task_modules/segments_ops.py b/projects/basic_tad/models/models/task_modules/segments_ops.py new file mode 100644 index 0000000000..d588ce79e5 --- /dev/null +++ b/projects/basic_tad/models/models/task_modules/segments_ops.py @@ -0,0 +1,287 @@ +import numpy as np +import torch +from mmcv.ops import batched_nms + + +def convert_1d_to_2d_bboxes(bboxes_1d, fixed_dim_value=0): + """ + Convert 1D bounding boxes to pseudo 2D bounding boxes by adding a 
fixed dimension. + + Args: + bboxes_1d (torch.Tensor): 1D bounding boxes tensor of shape (N, 2) + fixed_dim_value (float): Value to set for the fixed dimension in the 2D bounding boxes + + Returns: + torch.Tensor: Pseudo 2D bounding boxes tensor of shape (N, 4) + """ + # Get the number of bounding boxes + num_bboxes = bboxes_1d.shape[0] + + # Initialize the 2D bounding boxes tensor + bboxes_2d = torch.zeros((num_bboxes, 4), device=bboxes_1d.device, dtype=bboxes_1d.dtype) + + # Set the fixed dimension value for ymin and ymax + bboxes_2d[:, 1] = fixed_dim_value + bboxes_2d[:, 3] = fixed_dim_value + 1 + + # Copy the 1D intervals (xmin and xmax) to the 2D bounding boxes + bboxes_2d[:, 0::2] = bboxes_1d + + return bboxes_2d + + +def convert_2d_to_1d_bboxes(bboxes_2d): + """ + Convert pseudo 2D bounding boxes back to 1D bounding boxes by extracting xmin and xmax. + + Args: + bboxes_2d (torch.Tensor): Pseudo 2D bounding boxes tensor of shape (N, 4) + + Returns: + torch.Tensor: 1D bounding boxes tensor of shape (N, 2) + """ + # Extract xmin and xmax (first and third columns) from the 2D bounding boxes + bboxes_1d = bboxes_2d[:, 0::2] + + return bboxes_1d + + +def batched_nms1d(bboxes_1d, *args, **kwargs): + """ + Apply Non-Maximum Suppression (NMS) on 1D bounding boxes by converting them to pseudo 2D bounding boxes, + using a 2D NMS function, and converting the results back to 1D bounding boxes. + + Args: + bboxes_1d (torch.Tensor): 1D bounding boxes tensor of shape (N, 2) + *args: Additional arguments to pass to the batched_nms function + **kwargs: Additional keyword arguments to pass to the batched_nms function + + Returns: + tuple: A tuple containing: + - torch.Tensor: The 1D bounding boxes after NMS, with shape (N', 2) + - torch.Tensor: The indices of the kept bounding boxes, with shape (N',) + """ + bboxes_2d = convert_1d_to_2d_bboxes(bboxes_1d) + boxes, keep = batched_nms(bboxes_2d, *args, **kwargs) + return convert_2d_to_1d_bboxes(boxes), keep + + +# Below all adapted from basicTAD +def segment_overlaps(segments1, + segments2, + mode='iou', + is_aligned=False, + eps=1e-6, + detect_overlap_edge=False): + """Calculate overlap between two set of segments. + If ``is_aligned`` is ``False``, then calculate the ious between each + segment of segments1 and segments2, otherwise the ious between each aligned + pair of segments1 and segments2. + Args: + segments1 (Tensor): shape (m, 2) in format or empty. + segments2 (Tensor): shape (n, 2) in format or empty. + If is_aligned is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union) or iof (intersection over + foreground). 
+    Returns:
+        ious(Tensor): shape (m, n) if is_aligned == False else shape (m, 1)
+    Example:
+        >>> segments1 = torch.FloatTensor([
+        >>>     [0, 10],
+        >>>     [10, 20],
+        >>>     [32, 38],
+        >>> ])
+        >>> segments2 = torch.FloatTensor([
+        >>>     [0, 20],
+        >>>     [0, 19],
+        >>>     [10, 20],
+        >>> ])
+        >>> segment_overlaps(segments1, segments2)
+        tensor([[0.5000, 0.5263, 0.0000],
+                [0.0000, 0.4500, 1.0000],
+                [0.0000, 0.0000, 0.0000]])
+    Example:
+        >>> empty = torch.FloatTensor([])
+        >>> nonempty = torch.FloatTensor([
+        >>>     [0, 9],
+        >>> ])
+        >>> assert tuple(segment_overlaps(empty, nonempty).shape) == (0, 1)
+        >>> assert tuple(segment_overlaps(nonempty, empty).shape) == (1, 0)
+        >>> assert tuple(segment_overlaps(empty, empty).shape) == (0, 0)
+    """
+
+    is_numpy = False
+    if isinstance(segments1, np.ndarray):
+        segments1 = torch.from_numpy(segments1)
+        is_numpy = True
+    if isinstance(segments2, np.ndarray):
+        segments2 = torch.from_numpy(segments2)
+        is_numpy = True
+
+    segments1, segments2 = segments1.float(), segments2.float()
+
+    assert mode in ['iou', 'iof']
+    # Either the segments are empty or the length of the segments' last
+    # dimension is 2
+    assert (segments1.size(-1) == 2 or segments1.size(0) == 0)
+    assert (segments2.size(-1) == 2 or segments2.size(0) == 0)
+
+    rows = segments1.size(0)
+    cols = segments2.size(0)
+    if is_aligned:
+        assert rows == cols
+
+    if rows * cols == 0:
+        return segments1.new(rows, 1) if is_aligned else segments2.new(
+            rows, cols)
+
+    if is_aligned:
+        start = torch.max(segments1[:, 0], segments2[:, 0])  # [rows]
+        end = torch.min(segments1[:, 1], segments2[:, 1])  # [rows]
+
+        overlap = end - start
+        if detect_overlap_edge:
+            overlap[overlap == 0] += eps
+        overlap = overlap.clamp(min=0)  # [rows]
+        area1 = segments1[:, 1] - segments1[:, 0]
+
+        if mode == 'iou':
+            area2 = segments2[:, 1] - segments2[:, 0]
+            union = area1 + area2 - overlap
+        else:
+            union = area1
+    else:
+        start = torch.max(segments1[:, None, 0], segments2[:,
+                                                           0])  # [rows, cols]
+        end = torch.min(segments1[:, None, 1], segments2[:, 1])  # [rows, cols]
+
+        overlap = end - start
+        if detect_overlap_edge:
+            overlap[overlap == 0] += eps
+        overlap = overlap.clamp(min=0)  # [rows, cols]
+        area1 = segments1[:, 1] - segments1[:, 0]
+
+        if mode == 'iou':
+            area2 = segments2[:, 1] - segments2[:, 0]
+            union = area1[:, None] + area2 - overlap
+        else:
+            union = area1[:, None]
+
+    eps = union.new_tensor([eps])
+    union = torch.max(union, eps)
+    ious = overlap / union
+
+    if is_numpy:
+        ious = ious.numpy()
+
+    return ious
+
+
+def batched_nmw(bboxes,
+                scores,
+                labels,
+                nms_cfg):
+    """Non-Maximum Weighting (NMW) for multi-class 1D segments.
+
+    Args:
+        bboxes (Tensor): segments of shape (n, 2) in (start, end) format.
+        scores (Tensor): scores of shape (n,), one score per segment.
+        labels (Tensor): class labels of shape (n,), 0-based.
+        nms_cfg (dict): NMW config. Supported keys are `iou_thr` (IoU
+            threshold used to group segments), `max_num` (keep at most this
+            many results), `score_threshold` (segments with scores not above
+            it are discarded) and `score_factor` (factors multiplied to the
+            scores before filtering). May be None, in which case NMW is
+            skipped and the segments are only sorted by score.
+
+    Returns:
+        tuple: (dets, keep), where `dets` has shape (k, 3) in
+            (start, end, score) format and `keep` holds the indices of the
+            kept segments.
+    """
+
+    def _nmw1d(bboxes, scores, labels, remainings):
+        from mmdet.structures.bbox import bbox_overlaps
+        # select the best prediction (with the highest score) from the remaining.
+        keep_idx = remainings[0]
+        # collect predictions that have the same label as the best one.
+        mask = labels[remainings] == labels[keep_idx]
+        bboxes = bboxes[remainings][mask]
+        scores = scores[remainings][mask]
+        labels = labels[remainings][mask]
+
+        # NMS outputs the best prediction and deletes predictions that overlap it with IoU >= iou_thr,
+        # while NMW aggregates the best prediction and the predictions that overlap it with IoU >= iou_thr
+        # and outputs the aggregated prediction. The aggregation is weighted by the scores of the predictions.
+        ious = bbox_overlaps(bboxes[:1], bboxes, mode='iou')[0]
+        ious[0] = 1.0
+        iou_mask = ious >= iou_thr
+        aggregate_bboxes = bboxes[iou_mask]
+        accu_weights = scores[iou_mask] * ious[iou_mask]
+        accu_weights /= accu_weights.sum()
+        bbox = (accu_weights[:, None] * aggregate_bboxes).sum(dim=0)
+        score = scores[0]
+        label = labels[0]
+
+        # delete the aggregated predictions from the remaining.
+        inds = torch.nonzero(mask)[:, 0]
+        mask[inds[~iou_mask]] = False
+        remainings = remainings[~mask]
+
+        return bbox, score, label, remainings, keep_idx
+
+    # skip NMW when nms_cfg is None and only sort the segments by score
+    if nms_cfg is None:
+        scores, inds = scores.sort(descending=True)
+        bboxes = bboxes[inds]
+        return torch.cat([bboxes, scores[:, None]], -1), inds
+    score_factors = nms_cfg.pop('score_factor', None)
+    score_thr = nms_cfg.pop('score_threshold', 0)
+
+    # num_classes = scores.size(1)
+    # # exclude background category
+    # if bboxes.shape[1] > 2:
+    #     bboxes = bboxes.view(scores.size(0), -1, 2)
+    # else:
+    #     bboxes = bboxes[:, None].expand(-1, num_classes, 2)
+
+    # filter out segments with low scores
+    if score_factors is not None:
+        scores = scores * score_factors
+    valid_mask = scores > score_thr
+    bboxes = bboxes[valid_mask]
+    scores = scores[valid_mask]
+    # labels = valid_mask.nonzero(as_tuple=False)[:, 1]
+    labels = labels[valid_mask]
+
+    if bboxes.numel() == 0:
+        # nothing survives the score filtering: return empty dets and keep indices
+        dets = bboxes.new_zeros((0, 3))
+        return dets, labels.new_zeros((0,), dtype=torch.long)
+
+    remainings = scores.argsort(descending=True)
+
+    max_num = nms_cfg.get('max_num', -1)
+    iou_thr = nms_cfg.get('iou_thr')
+
+    results = []
+    while remainings.numel() > 0:
+        bbox, score, label, remainings, keep_idx = _nmw1d(bboxes, scores, labels, remainings)
+        results.append([bbox, score, label, keep_idx])
+        if max_num > 0 and len(results) == max_num:
+            break
+
+    if len(results) == 0:
+        bboxes = bboxes.new_zeros((0, 2))
+        scores = scores.new_zeros((0,))
+        labels = labels.new_zeros((0,))
+        keep = labels.new_zeros((0,))
+    else:
+        bboxes, scores, labels, keep = list(zip(*results))
+        bboxes = torch.stack(bboxes)
+        scores = torch.stack(scores)
+        labels = torch.stack(labels)
+        keep = torch.stack(keep)
+    dets = torch.cat([bboxes, scores[:, None]], dim=-1)
+
+    return dets, keep
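
The segment utilities above operate on plain `(start, end)` tensors; the pseudo-2D trick only pads a dummy vertical extent so that mmcv's `batched_nms` can be reused on temporal segments. A minimal usage sketch (not part of the patch; the import path assumes `projects/basic_tad` has been added to `PYTHONPATH` as described in the README, and the `nms_cfg` values are placeholders):

```python
import torch

# assumes `projects/basic_tad` is on PYTHONPATH (see the README of this project)
from models.models.task_modules.segments_ops import (batched_nms1d,
                                                      segment_overlaps)

# three candidate action segments on the temporal axis, (start, end) in frames
segments = torch.tensor([[0., 10.], [2., 12.], [30., 40.]])
scores = torch.tensor([0.9, 0.6, 0.8])
labels = torch.tensor([0, 0, 1])

# pairwise temporal IoU between all segments, shape (3, 3)
ious = segment_overlaps(segments, segments)

# NMS via the pseudo-2D trick: each segment becomes an (x1, 0, x2, 1) box,
# mmcv's batched_nms suppresses overlapping boxes per class, and the result
# is converted back to 1D; `dets` holds (start, end, score) rows and `keep`
# the indices of the surviving proposals
dets, keep = batched_nms1d(segments, scores, labels,
                           nms_cfg=dict(type='nms', iou_threshold=0.5))
```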