diff --git a/configs/localization/tcanet/README.md b/configs/localization/tcanet/README.md
new file mode 100644
index 0000000000..e04f154c4a
--- /dev/null
+++ b/configs/localization/tcanet/README.md
@@ -0,0 +1,66 @@
+# TCANet
+
+[Temporal Context Aggregation Network for Temporal Action Proposal Refinement](https://openaccess.thecvf.com/content/CVPR2021/papers/Qing_Temporal_Context_Aggregation_Network_for_Temporal_Action_Proposal_Refinement_CVPR_2021_paper.pdf)
+
+## Abstract
+
+Temporal action proposal generation aims to estimate temporal intervals of actions in untrimmed videos, which is a challenging yet important task in the video understanding field.
+The proposals generated by current methods still suffer from inaccurate temporal boundaries and inferior confidence used for retrieval, owing to the lack of efficient temporal modeling and effective boundary context utilization.
+In this paper, we propose the Temporal Context Aggregation Network (TCANet) to generate high-quality action proposals through `local and global` temporal context aggregation and complementary as well as progressive boundary refinement.
+Specifically, we first design a Local-Global Temporal Encoder (LGTE), which adopts a channel-grouping strategy to efficiently encode both `local and global` temporal inter-dependencies.
+Furthermore, both the boundary and internal context of proposals are adopted for frame-level and segment-level boundary regressions, respectively.
+The Temporal Boundary Regressor (TBR) is designed to combine these two regression granularities in an end-to-end fashion, achieving precise boundaries and reliable confidence for proposals through progressive refinement. Extensive experiments are conducted on three challenging datasets: HACS, ActivityNet-v1.3, and THUMOS-14, where TCANet generates proposals with high precision and recall. Combined with an existing action classifier, TCANet obtains remarkable temporal action detection performance compared with other methods. Not surprisingly, the proposed TCANet won 1$^{st}$ place on the temporal action localization task of the CVPR 2020 HACS challenge leaderboard.
+
+ +
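+The core of TCANet is the Local-Global Temporal Encoder, which splits its attention heads into a local group that is restricted to a small temporal window and a global group that attends over the whole feature sequence. The snippet below is a minimal, illustrative sketch of that windowed attention mask, mirroring the `_mask_matrix` helper in `mmaction/models/localizers/tcanet.py`; it is not part of the released code.
+
+```python
+import torch
+
+
+def local_global_mask(num_heads: int, seq_len: int,
+                      window_size: int) -> torch.Tensor:
+    """Boolean attention mask (True = masked out): the first half of the
+    heads only see a local window around each position, while the remaining
+    heads attend globally."""
+    mask = torch.zeros(num_heads, seq_len, seq_len, dtype=torch.bool)
+    idx = torch.arange(seq_len)
+    for head in range(num_heads // 2):  # local heads
+        for t in range(seq_len):
+            mask[head, t] = (idx - t).abs() > window_size / 2
+    return mask
+
+
+# 8 heads over 100 temporal snippets with a window of 9, as in the HACS config
+print(local_global_mask(8, 100, 9).shape)  # torch.Size([8, 100, 100])
+```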
+
+## Results and Models
+
+### HACS dataset
+
+| feature  | gpus | pretrain |  AUC  | AR@1  | AR@5  | AR@10 | AR@100 | gpu_mem(M) | iter time(s) | config | ckpt | log |
+| :------: | :--: | :------: | :---: | :---: | :---: | :---: | :----: | :--------: | :----------: | :----: | :--: | :-: |
+| SlowOnly |  2   |   None   | 68.33 | 32.89 | 49.43 | 56.64 | 75.29  |    5412    |      -       | [config](/configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature_20230619-95fd88b0.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.log) |
+
+For more details on data preparation, you can refer to [HACS Data Preparation](/tools/data/hacs/README.md).
+
+## Train
+
+Train the TCANet model on the HACS dataset with the SlowOnly feature:
+
+```shell
+bash tools/dist_train.sh configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py 2
+```
+
+For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
+
+## Test
+
+Test the TCANet model on the HACS dataset with the SlowOnly feature:
+
+```shell
+python3 tools/test.py configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py CHECKPOINT.PTH
+```
+
+For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).
+
+## Citation
+
+```BibTeX
+@inproceedings{qing2021temporal,
+  title={Temporal Context Aggregation Network for Temporal Action Proposal Refinement},
+  author={Qing, Zhiwu and Su, Haisheng and Gan, Weihao and Wang, Dongliang and Wu, Wei and Wang, Xiang and Qiao, Yu and Yan, Junjie and Gao, Changxin and Sang, Nong},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={485--494},
+  year={2021}
+}
+```
diff --git a/configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py b/configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py
new file mode 100644
index 0000000000..ff986dc045
--- /dev/null
+++ b/configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py
@@ -0,0 +1,121 @@
+_base_ = '../../_base_/default_runtime.py'
+
+# model settings
+model = dict(
+    type='TCANet',
+    feat_dim=2048,
+    se_sample_num=32,
+    action_sample_num=64,
+    temporal_dim=100,
+    window_size=9,
+    lgte_num=2,
+    soft_nms_alpha=0.4,
+    soft_nms_low_threshold=0.0,
+    soft_nms_high_threshold=0.0,
+    post_process_top_k=100,
+    feature_extraction_interval=16)
+
+# dataset settings
+dataset_type = 'ActivityNetDataset'
+data_root = 'data/HACS/slowonly_feature/'
+data_root_val = 'data/HACS/slowonly_feature/'
+ann_file_train = 'data/HACS/hacs_anno_train.json'
+ann_file_val = 'data/HACS/hacs_anno_val.json'
+ann_file_test = 'data/HACS/hacs_anno_val.json'
+
+train_pipeline = [
+    dict(type='LoadLocalizationFeature'),
+    dict(type='GenerateLocalizationLabels'),
+    dict(
+        type='PackLocalizationInputs',
+        keys=('gt_bbox', ),
+        meta_keys=('video_name', ))
+]
+
+val_pipeline = [
+    dict(type='LoadLocalizationFeature'),
+    dict(type='GenerateLocalizationLabels'),
+    dict(
+        type='PackLocalizationInputs',
+        keys=('gt_bbox', ),
+        meta_keys=('video_name', 'duration_second', 'duration_frame',
+                   'annotations', 'feature_frame'))
+]
+
+test_pipeline = [
+    dict(type='LoadLocalizationFeature'),
+    dict(
type='PackLocalizationInputs', + keys=('gt_bbox', ), + meta_keys=('video_name', 'duration_second', 'duration_frame', + 'annotations', 'feature_frame')) +] + +train_dataloader = dict( + batch_size=8, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + drop_last=True, + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +max_epochs = 9 +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_begin=1, + val_interval=1) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict( + optimizer=dict(type='Adam', lr=0.001, weight_decay=0.0001), + clip_grad=dict(max_norm=40, norm_type=2)) + +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[ + 7, + ], + gamma=0.1) +] + +work_dir = './work_dirs/tcanet_2xb8-2048x100-9e_hacs-feature/' +test_evaluator = dict( + type='ANetMetric', + metric_type='AR@AN', + dump_config=dict(out=f'{work_dir}/results.json', output_format='json')) +val_evaluator = test_evaluator diff --git a/mmaction/models/localizers/bsn.py b/mmaction/models/localizers/bsn.py index c35e3d9bca..a2b61a37cb 100644 --- a/mmaction/models/localizers/bsn.py +++ b/mmaction/models/localizers/bsn.py @@ -19,6 +19,7 @@ class TEM(BaseModel): Code reference https://github.com/wzmsltw/BSN-boundary-sensitive-network Args: + temporal_dim (int): Total frames selected for each video. tem_feat_dim (int): Feature dimension. tem_hidden_dim (int): Hidden layer dimension. tem_match_threshold (float): Temporal evaluation match threshold. diff --git a/mmaction/models/localizers/tcanet.py b/mmaction/models/localizers/tcanet.py new file mode 100644 index 0000000000..a606407a6a --- /dev/null +++ b/mmaction/models/localizers/tcanet.py @@ -0,0 +1,477 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.model import BaseModel +from torch import Tensor, nn +from utils import (batch_iou, bbox_se_transform_batch, bbox_se_transform_inv, + bbox_xw_transform_batch, bbox_xw_transform_inv, + post_processing) + +from mmaction.registry import MODELS +from mmaction.utils import OptConfigType + + +class LGTE(BaseModel): + """Local-Global Temporal Encoder (LGTE) + + Args: + input_dim (int): Input feature dimension. + dropout (float): the dropout rate for the residual branch of + self-attention and ffn. + temporal_dim (int): Total frames selected for each video. + Defaults to 100. + window_size (int): the window size for Local Temporal Encoder. + Defaults to 9. + init_cfg (dict or ConfigDict, optional): The Config for + initialization. Defaults to None. 
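+        num_heads (int): Number of heads for the multi-head attention.
+            Defaults to 8.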
+ """ + + def __init__(self, + input_dim: int, + dropout: float, + temporal_dim: int = 100, + window_size: int = 9, + num_heads: int = 8, + init_cfg: OptConfigType = None, + **kwargs) -> None: + super(LGTE, self).__init__(init_cfg) + + self.atten = MultiheadAttention( + embed_dims=input_dim, + num_heads=num_heads, + proj_drop=dropout, + attn_drop=0.1) + self.ffn = FFN( + embed_dims=input_dim, feedforward_channels=256, ffn_drop=dropout) + + norm_cfg = dict(type='LN', eps=1e-6) + self.norm1 = build_norm_layer(norm_cfg, input_dim)[1] + self.norm1 = build_norm_layer(norm_cfg, input_dim)[1] + + mask = self._mask_matrix(num_heads, temporal_dim, window_size) + self.register_buffer('mask', mask) + + def forward(self, x: Tensor) -> Tensor: + """Forward call for LGTE. + + Args: + x (torch.Tensor): The input tensor with shape (B, T, C) + """ + x = x.permute(2, 0, 1) + mask = self.mask.repeat(x.size(1), 1, 1, 1) + x = self.atten(x, mask) + x = self.norm1(x) + x = self.ffn(x) + x = self.norm2(x) + x = x.permute(1, 2, 0) + return x + + @staticmethod + def _mask_matrix(num_heads: int, temporal_dim: int, + window_size: int) -> Tensor: + mask = torch.zeros(num_heads, temporal_dim, temporal_dim) + index = torch.arange(temporal_dim) + + for i in range(num_heads // 2): + for j in range(temporal_dim): + ignored = (index - j).abs() > window_size / 2 + mask[i, j] = ignored + + return mask.unsqueeze(0).bool() + + +def StartEndRegressor(sample_num: int, feat_dim: int) -> nn.Module: + """Start and End Regressor in the Temporal Boundary Regressor. + + Args: + sample_num (int): number of samples for the start & end. + feat_dim (int): feature dimension. + + Returns: + A pytorch module that works as the start and end regressor. The input + of the module should have a shape of (B, feat_dim * 2, sample_num). + """ + hidden_dim = 128 + regressor = nn.Sequential( + nn.Conv1d( + feat_dim * 2, + hidden_dim * 2, + kernel_size=3, + padding=1, + groups=8, + stride=2), nn.ReLU(inplace=True), + nn.Conv1d( + hidden_dim * 2, + hidden_dim * 2, + kernel_size=3, + padding=1, + groups=8, + stride=2), nn.ReLU(inplace=True), + nn.Conv1d(hidden_dim * 2, 2, kernel_size=sample_num // 4, groups=2), + nn.Flatten()) + return regressor + + +def CenterWidthRegressor(temporal_len: int, feat_dim: int) -> nn.Module: + """Center Width in the Temporal Boundary Regressor. + + Args: + temporal_len (int): temporal dimension of the inputs. + feat_dim (int): feature dimension. + + Returns: + A pytorch module that works as the start and end regressor. The input + of the module should have a shape of (B, feat_dim, temporal_len). 
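+        The last convolution outputs 3 channels: the (center, width)
+        regression offsets and a confidence score, which ``TBR`` turns into
+        the predicted IoU via a sigmoid.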
+ """ + hidden_dim = 512 + regressor = nn.Sequential( + nn.Conv1d( + feat_dim, hidden_dim, kernel_size=3, padding=1, groups=4, + stride=2), nn.ReLU(inplace=True), + nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=1, + groups=4, + stride=2), nn.ReLU(inplace=True), + nn.Conv1d( + hidden_dim, hidden_dim, kernel_size=temporal_len // 4, groups=4), + nn.ReLU(inplace=True), nn.Conv1d(hidden_dim, 3, kernel_size=1)) + return regressor + + +class TemporalTransform: + """Temporal Transform to sample temporal features.""" + + def __init__(self, prop_boundary_ratio: float, action_sample_num: int, + se_sample_num: int, temporal_interval: int): + super(TemporalTransform, self).__init__() + self.temporal_interval = temporal_interval + self.prop_boundary_ratio = prop_boundary_ratio + self.action_sample_num = action_sample_num + self.se_sample_num = se_sample_num + + def forward(self, segments: Tensor, features: Tensor) -> List[Tensor]: + s_len = segments[:, 1] - segments[:, 0] + starts_segments = [ + segments[:, 0] - self.prop_boundary_ratio * s_len, segments[:, 0] + ] + starts_segments = torch.stack(starts_segments, dim=1) + + ends_segments = [ + segments[:, 1], segments[:, 1] + self.prop_boundary_ratio * s_len + ] + ends_segments = torch.stack(ends_segments, dim=1) + + starts_feature = self._sample_one_temporal(starts_segments, + self.se_sample_num, + features) + ends_feature = self._sample_one_temporal(ends_segments, + self.se_sample_num, features) + actions_feature = self._sample_one_temporal(segments, + self.action_sample_num, + features) + return starts_feature, actions_feature, ends_feature + + def _sample_one_temporal(self, segments: Tensor, out_len: int, + features: Tensor) -> Tensor: + total_temporal_len = features.size(2) * self.temporal_interval + segments = torch.clamp(segments / total_temporal_len, max=1., min=0.) 
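+        # ``segments`` is now expressed in [0, 1] of the full temporal
+        # extent. Map it to the [-1, 1] coordinate system expected by
+        # ``F.affine_grid``/``F.grid_sample``, then build a 1D affine
+        # transform whose scale and translation pick ``out_len`` uniformly
+        # spaced samples inside each segment.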
+ segments = segments * 2 - 1 + + theta = segments.new_zeros((features.size(0), 2, 3)) + theta[:, 1, 1] = 1.0 + theta[:, 0, 0] = (segments[:, 1] - segments[:, 0]) / 2.0 + theta[:, 0, 2] = (segments[:, 1] + segments[:, 0]) / 2.0 + + size = torch.Size((*features.shape[:2], 1, out_len)) + grid = F.affine_grid(theta, size) + stn_feature = F.grid_sample(features.unsqueeze(2), grid) + stn_feature = stn_feature.view(*features.shape[:2], out_len) + return stn_feature + + +class TBR(BaseModel): + """Temporal Boundary Regressor (TBR)""" + + def __init__(self, + se_sample_num: int, + action_sample_num: int, + temporal_dim: int, + prop_boundary_ratio: float = 0.5, + init_cfg: OptConfigType = None, + **kwargs) -> None: + super(TBR, self).__init__(init_cfg) + + hidden_dim = 512 + + self.reg1se = StartEndRegressor(se_sample_num, hidden_dim) + temporal_len = se_sample_num * 2 + action_sample_num + self.reg1xw = CenterWidthRegressor(temporal_len, hidden_dim) + self.ttn = TemporalTransform(prop_boundary_ratio, action_sample_num, + se_sample_num, temporal_dim) + + def forward(self, proposals: Tensor, features: Tensor, gt_boxes: Tensor, + iou_thres: float, training: bool) -> tuple: + proposals1 = proposals[:, :2] + starts_feat1, actions_feat1, ends_feat1 = self.ttn( + proposals1, features) + reg1se = self.reg1se(starts_feat1, ends_feat1) + features1xw = torch.cat([starts_feat1, actions_feat1, ends_feat1], + dim=2) + reg1xw = self.reg1xw(features1xw).squeeze(2) + preds_iou1 = reg1xw[:, 2].sigmoid() + reg1xw = reg1xw[:, :2] + + if training: + proposals2xw = bbox_xw_transform_inv(proposals1, reg1xw, 0.1, 0.2) + proposals2se = bbox_se_transform_inv(proposals1, reg1se, 1.0) + + iou1 = batch_iou(proposals1, gt_boxes) + targets1se = bbox_se_transform_batch(proposals1, gt_boxes) + targets1xw = bbox_xw_transform_batch(proposals1, gt_boxes) + rloss1se = self.regress_loss(reg1se, targets1se, iou1, iou_thres) + rloss1xw = self.regress_loss(reg1xw, targets1xw, iou1, iou_thres) + rloss1 = rloss1se + rloss1xw + iloss1 = self.iou_loss(preds_iou1, iou1, iou_thres=iou_thres) + else: + proposals2xw = bbox_xw_transform_inv(proposals1, reg1xw, 0.1, 0.2) + proposals2se = bbox_se_transform_inv(proposals1, reg1se, 0.2) + rloss1 = iloss1 = 0 + proposals2 = (proposals2se + proposals2xw) / 2.0 + proposals2 = torch.clamp(proposals2, min=0.) + return preds_iou1, proposals2, rloss1, iloss1 + + def regress_loss(self, regression, targets, iou_with_gt, iou_thres): + weight = (iou_with_gt >= iou_thres).float().unsqueeze(1) + reg_loss = F.smooth_l1_loss(regression, targets, reduction='none') + if weight.sum() > 0: + reg_loss = (weight * reg_loss).sum() / weight.sum() + else: + reg_loss = (weight * reg_loss).sum() + return reg_loss + + def iou_loss(self, preds_iou, match_iou, iou_thres): + preds_iou = preds_iou.view(-1) + u_hmask = (match_iou > iou_thres).float() + u_mmask = ((match_iou <= iou_thres) & (match_iou > 0.3)).float() + u_lmask = (match_iou <= 0.3).float() + + num_h, num_m, num_l = u_hmask.sum(), u_mmask.sum(), u_lmask.sum() + + bs, device = u_hmask.size()[0], u_hmask.device + + r_m = min(num_h / num_m, 1) + u_smmask = torch.rand(bs, device=device) * u_mmask + u_smmask = (u_smmask > (1. - r_m)).float() + + r_l = min(num_h / num_l, 1) + u_slmask = torch.rand(bs, device=device) * u_lmask + u_slmask = (u_slmask > (1. 
- r_l)).float() + + iou_weights = u_hmask + u_smmask + u_slmask + iou_loss = F.smooth_l1_loss(preds_iou, match_iou, reduction='none') + if iou_weights.sum() > 0: + iou_loss = (iou_loss * iou_weights).sum() / iou_weights.sum() + else: + iou_loss = (iou_loss * iou_weights).sum() + return iou_loss + + +@MODELS.register_module() +class TCANet(BaseModel): + """Temporal Context Aggregation Network. + + Please refer `Temporal Context Aggregation Network for Temporal Action + Proposal Refinement `_. + Code Reference: + https://github.com/qinzhi-0110/Temporal-Context-Aggregation-Network-Pytorch + """ + + def __init__(self, + feat_dim: int = 2304, + se_sample_num: int = 32, + action_sample_num: int = 64, + temporal_dim: int = 100, + window_size: int = 9, + lgte_num: int = 2, + soft_nms_alpha: float = 0.4, + soft_nms_low_threshold: float = 0.0, + soft_nms_high_threshold: float = 0.0, + post_process_top_k: int = 100, + feature_extraction_interval: int = 16, + init_cfg: OptConfigType = None, + **kwargs) -> None: + super(TCANet, self).__init__(init_cfg) + + self.soft_nms_alpha = soft_nms_alpha + self.soft_nms_low_threshold = soft_nms_low_threshold + self.soft_nms_high_threshold = soft_nms_high_threshold + self.feature_extraction_interval = feature_extraction_interval + self.post_process_top_k = post_process_top_k + + hidden_dim = 512 + self.x_1d_b_f = nn.Sequential( + nn.Conv1d( + feat_dim, hidden_dim, kernel_size=3, padding=1, groups=4), + nn.ReLU(inplace=True), + nn.Conv1d( + hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=4), + nn.ReLU(inplace=True), + ) + + for i in 1, 2, 3: + tbr = TBR( + se_sample_num=se_sample_num, + action_sample_num=action_sample_num, + temporal_dim=temporal_dim, + init_cfg=init_cfg, + **kwargs) + setattr(self, f'tbr{i}', tbr) + + self.lgtes = nn.ModuleList([ + LGTE( + input_dim=hidden_dim, + dropout=0.1, + temporal_dim=temporal_dim, + window_size=window_size, + init_cfg=init_cfg, + **kwargs) for i in range(lgte_num) + ]) + + def forward(self, inputs, data_samples, mode, **kwargs): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: + + - ``tensor``: Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - ``predict``: Forward and return the predictions, which are fully + processed to a list of :obj:`ActionDataSample`. + - ``loss``: Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[:obj:`ActionDataSample`], optional): The + annotation data of every samples. Defaults to None. + mode (str): Return what kind of value. Defaults to ``tensor``. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of ``ActionDataSample``. + - If ``mode="loss"``, return a dict of tensor. + """ + if not isinstance(input, Tensor): + inputs = torch.stack(inputs) + if mode == 'tensor': + return self._forward(inputs, **kwargs) + if mode == 'predict': + return self.predict(inputs, data_samples, **kwargs) + elif mode == 'loss': + return self.loss(inputs, data_samples, **kwargs) + else: + raise RuntimeError(f'Invalid mode "{mode}". 
' + 'Only supports loss, predict and tensor mode') + + def _forward(self, x): + """Define the computation performed at every call. + + Args: + x (torch.Tensor): The input data. + Returns: + torch.Tensor: The output of the module. + """ + x = self.x_1d_b_f(x) + for layer in self.lgtes: + x = layer(x) + return x + + def loss(self, batch_inputs, batch_data_samples, **kwargs): + features = self._forward(batch_inputs) + gt_boxes = torch.stack( + [sample.gt_instances['gt_bbox'] for sample in batch_data_samples]) + proposals = torch.stack( + [sample.proposals['proposals'] for sample in batch_data_samples]) + + batch_size = proposals.size(0) + proposals_num = proposals.size(1) + for i in range(batch_size): + proposals[i, :, 2] = i + proposals = proposals.view(batch_size * proposals_num, 3) + proposals_select = proposals[:, 0:2].sum(dim=1) > 0 + proposals = proposals[proposals_select, :] + + features = features[proposals[:, 2].long()] + + gt_boxes = gt_boxes.view(batch_size * proposals_num, 2) + gt_boxes = gt_boxes[proposals_select, :] + + _, proposals1, rloss1, iloss1 = self.tbr1(proposals, features, + gt_boxes, 0.5, True) + _, proposals2, rloss2, iloss2 = self.tbr2(proposals1, features, + gt_boxes, 0.6, True) + _, _, rloss3, iloss3 = self.tbr3(proposals2, features, gt_boxes, 0.7, + True) + + loss_dict = dict( + rloss1=rloss1, + rloss2=rloss2, + rloss3=rloss3, + iloss1=iloss1, + iloss2=iloss2, + iloss3=iloss3) + return loss_dict + + def predict(self, batch_inputs, batch_data_samples, **kwargs): + features = self._forward(batch_inputs) + gt_boxes = torch.stack( + [sample.gt_instances['gt_bbox'] for sample in batch_data_samples]) + proposals = torch.stack( + [sample.proposals['proposals'] for sample in batch_data_samples]) + scores = torch.stack( + [sample.proposals['scores'] for sample in batch_data_samples]) + + batch_size = proposals.size(0) + proposals_num = proposals.size(1) + for i in range(batch_size): + proposals[i, :, 2] = i + proposals = proposals.view(batch_size * proposals_num, 3) + proposals_select = proposals[:, 0:2].sum(dim=1) > 0 + proposals = proposals[proposals_select, :] + + features = features[proposals[:, 2].long()] + + preds_iou1, proposals1, _, _ = self.tbr1(proposals, features, gt_boxes, + 0.5, False) + preds_iou2, proposals2, _, _ = self.tbr2(proposals1, features, + gt_boxes, 0.6, False) + preds_iou3, proposals3, _, _ = self.tbr3(proposals2, features, + gt_boxes, 0.7, False) + + all_proposals = [torch.cat([proposals, scores], dim=1)] + all_proposals += [torch.cat([proposals1, scores * preds_iou1], dim=1)] + all_proposals += [torch.cat([proposals2, scores * preds_iou2], dim=1)] + all_proposals += [torch.cat([proposals3, scores * preds_iou3], dim=1)] + + all_proposals = torch.stack(all_proposals).cpu().numpy() + + video_info = batch_data_samples[0].metainfo + proposal_list = post_processing(all_proposals, video_info, + self.soft_nms_alpha, + self.soft_nms_low_threshold, + self.soft_nms_high_threshold, + self.post_process_top_k, + self.feature_extraction_interval) + return proposal_list diff --git a/mmaction/models/localizers/utils/tcanet_utils.py b/mmaction/models/localizers/utils/tcanet_utils.py new file mode 100644 index 0000000000..33b35bcf89 --- /dev/null +++ b/mmaction/models/localizers/utils/tcanet_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
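+# Box-style transforms on 1D temporal segments (start, end): IoU computation
+# and the (center, width) / (start, end) regression-target encodings used by
+# the Temporal Boundary Regressors in TCANet.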
+# Copied from +# 'https://github.com/qinzhi-0110/' +# 'Temporal-Context-Aggregation-Network-Pytorch/' +# 'blob/main/utils.py' +# TODO: refactor +import torch + + +def batch_iou(proposals, gt_boxes): + len_proposals = proposals[:, 1] - proposals[:, 0] + int_xmin = torch.max(proposals[:, 0], gt_boxes[:, 0]) + int_xmax = torch.min(proposals[:, 1], gt_boxes[:, 1]) + inter_len = torch.clamp(int_xmax - int_xmin, min=0.) + union_len = len_proposals - inter_len + gt_boxes[:, 1] - gt_boxes[:, 0] + jaccard = inter_len / (union_len + 0.00001) + return jaccard + + +def bbox_xw_transform_inv(boxes, deltas, dx_w, dw_w): + widths = boxes[:, 1] - boxes[:, 0] + ctr_x = boxes[:, 0] + 0.5 * widths + + dx = deltas[:, 0] * dx_w + dw = deltas[:, 1] * dw_w + + pred_ctr_x = dx * widths + ctr_x + pred_w = torch.exp(dw) * widths + + pred_boxes = deltas.clone() + # x1 + pred_boxes[:, 0] = pred_ctr_x - 0.5 * pred_w + # x2 + pred_boxes[:, 1] = pred_ctr_x + 0.5 * pred_w + + return pred_boxes + + +def bbox_xw_transform_batch(ex_rois, gt_rois): + ex_widths = torch.clamp(ex_rois[:, 1] - ex_rois[:, 0], min=0.00001) + ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths + + gt_widths = torch.clamp(gt_rois[:, 1] - gt_rois[:, 0], min=0.00001) + gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths + + targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dw = torch.log(gt_widths / ex_widths) + targets = torch.stack((targets_dx, targets_dw), dim=1) + return targets + + +def bbox_se_transform_batch(ex_rois, gt_rois): + ex_widths = torch.clamp(ex_rois[:, 1] - ex_rois[:, 0], min=0.00001) + + s_offset = gt_rois[:, 0] - ex_rois[:, 0] + e_offset = gt_rois[:, 1] - ex_rois[:, 1] + + targets_s = s_offset / ex_widths + targets_e = e_offset / ex_widths + targets = torch.stack((targets_s, targets_e), dim=1) + return targets + + +def bbox_se_transform_inv(boxes, deltas, dse_w): + widths = boxes[:, 1] - boxes[:, 0] + s_offset = deltas[:, 0] * widths * dse_w + e_offset = deltas[:, 1] * widths * dse_w + pred_boxes = deltas.clone() + pred_boxes[:, 0] = boxes[:, 0] + s_offset + pred_boxes[:, 1] = boxes[:, 1] + e_offset + return pred_boxes
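+
+
+# Illustrative note (not part of the original utilities): with unit weights,
+# each ``*_transform_inv`` function inverts the targets produced by the
+# corresponding ``*_transform_batch`` function, e.g.::
+#
+#     >>> proposals = torch.tensor([[10., 50.]])
+#     >>> gt = torch.tensor([[12., 46.]])
+#     >>> deltas = bbox_xw_transform_batch(proposals, gt)
+#     >>> bbox_xw_transform_inv(proposals, deltas, dx_w=1.0, dw_w=1.0)
+#     tensor([[12., 46.]])  # recovers ``gt`` up to floating-point error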