[Feature] Support TCANet on HACS dataset (#2271)
hukkai committed Jun 19, 2023
1 parent 590febc commit 5f3b774
Showing 5 changed files with 736 additions and 0 deletions.
66 changes: 66 additions & 0 deletions configs/localization/tcanet/README.md
@@ -0,0 +1,66 @@
# TCANet

[Temporal Context Aggregation Network for Temporal Action Proposal Refinement](https://openaccess.thecvf.com/content/CVPR2021/papers/Qing_Temporal_Context_Aggregation_Network_for_Temporal_Action_Proposal_Refinement_CVPR_2021_paper.pdf)

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

Temporal action proposal generation aims to estimate temporal intervals of actions in untrimmed videos, which is a challenging yet important task in the video understanding field.
The proposals generated by current methods still suffer from inaccurate temporal boundaries and inferior confidence used for retrieval owing to the lack of efficient temporal modeling and effective boundary context utilization.
In this paper, we propose Temporal Context Aggregation Network (TCANet) to generate high-quality action proposals through local and global temporal context aggregation and complementary as well as progressive boundary refinement.
Specifically, we first design a Local-Global Temporal Encoder (LGTE), which adopts a channel grouping strategy to efficiently encode both local and global temporal inter-dependencies.
Furthermore, both the boundary and internal context of proposals are adopted for frame-level and segment-level boundary regressions, respectively.
Temporal Boundary Regressor (TBR) is designed to combine these two regression granularities in an end-to-end fashion, achieving precise boundaries and reliable confidence for proposals through progressive refinement.
Extensive experiments are conducted on three challenging datasets: HACS, ActivityNet-v1.3, and THUMOS-14, where TCANet generates proposals with high precision and recall.
By combining with an existing action classifier, TCANet obtains remarkable temporal action detection performance compared with other methods.
Not surprisingly, the proposed TCANet won 1$^{st}$ place in the CVPR 2020 HACS challenge leaderboard on the temporal action localization task.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/35267818/223302449-8891241c-e84a-4c74-bf31-073d6a75b33a.png" width="800"/>
</div>

## Results and Models

### HACS dataset

| feature | gpus | pretrain | AUC | AR@1 | AR@5 | AR@10 | AR@100 | gpu_mem(M) | iter time(s) | config | ckpt | log |
| :------: | :--: | :------: | :---: | :---: | :---: | :---: | :----: | :--------: | :----------: | :-------------------------------------------: | :------------------------------------------: | :-----------------------------------------: |
| SlowOnly | 2 | None | 68.33 | 32.89 | 49.43 | 56.64 | 75.29 | 5412 | - | [config](/configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature_20230619-95fd88b0.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.log) |

For more details on data preparation, you can refer to [HACS Data Preparation](/tools/data/hacs/README.md).
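With the paths used in the config, the data is expected to be organized roughly as follows (a sketch; the exact feature file naming depends on your extraction setup):

```
data/HACS/
├── hacs_anno_train.json
├── hacs_anno_val.json
└── slowonly_feature/
    └── <one feature file per video>
```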

## Train

Train the TCANet model on the HACS dataset with the SlowOnly features.

```shell
bash tools/dist_train.sh configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py 2
```
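
With a single GPU, a non-distributed run should also work via the standard entry point (the learning rate may need adjusting if the effective batch size changes):

```shell
python3 tools/train.py configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py
```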

For more details, you can refer to the **Training** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).

## Test

Test the TCANet model on the HACS dataset with the SlowOnly features.

```shell
python3 tools/test.py configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py CHECKPOINT.PTH
```
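
For example, to evaluate the released checkpoint from the table above (MMEngine can load a checkpoint directly from a URL):

```shell
python3 tools/test.py configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py \
    https://download.openmmlab.com/mmaction/v1.0/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature_20230619-95fd88b0.pth
```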

For more details, you can refer to the **Testing** part in the [Training and Test Tutorial](/docs/en/user_guides/train_test.md).

## Citation

<!-- [DATASET] -->

```BibTeX
@inproceedings{qing2021temporal,
title={Temporal Context Aggregation Network for Temporal Action Proposal Refinement},
author={Qing, Zhiwu and Su, Haisheng and Gan, Weihao and Wang, Dongliang and Wu, Wei and Wang, Xiang and Qiao, Yu and Yan, Junjie and Gao, Changxin and Sang, Nong},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={485--494},
year={2021}
}
```
121 changes: 121 additions & 0 deletions configs/localization/tcanet/tcanet_2xb8-2048x100-9e_hacs-feature.py
@@ -0,0 +1,121 @@
_base_ = '../../_base_/default_runtime.py'

# model settings
model = dict(
    type='TCANet',
    feat_dim=2048,  # channel dimension of the input SlowOnly features
    se_sample_num=32,  # sampled points around proposal starts/ends
    action_sample_num=64,  # sampled points inside each proposal
    temporal_dim=100,  # rescaled temporal length of each feature sequence
    window_size=9,
    lgte_num=2,  # number of stacked Local-Global Temporal Encoders (LGTE)
    # Soft-NMS settings for proposal post-processing
    soft_nms_alpha=0.4,
    soft_nms_low_threshold=0.0,
    soft_nms_high_threshold=0.0,
    post_process_top_k=100,  # keep the top-100 proposals after Soft-NMS
    feature_extraction_interval=16)  # frame interval of the extracted features

# dataset settings
dataset_type = 'ActivityNetDataset'
data_root = 'data/HACS/slowonly_feature/'
data_root_val = 'data/HACS/slowonly_feature/'
ann_file_train = 'data/HACS/hacs_anno_train.json'
ann_file_val = 'data/HACS/hacs_anno_val.json'
ann_file_test = 'data/HACS/hacs_anno_val.json'
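
# NOTE: a sketch of the expected ActivityNet-style annotation layout consumed
# by `ActivityNetDataset` (field names mirror the meta_keys used below; the
# concrete values are illustrative only):
# {
#     "v_example": {
#         "duration_second": 12.5,
#         "duration_frame": 375,
#         "feature_frame": 368,
#         "annotations": [{"segment": [0.8, 9.6], "label": "SomeAction"}]
#     }
# }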

train_pipeline = [
dict(type='LoadLocalizationFeature'),
dict(type='GenerateLocalizationLabels'),
dict(
type='PackLocalizationInputs',
keys=('gt_bbox', ),
meta_keys=('video_name', ))
]

val_pipeline = [
dict(type='LoadLocalizationFeature'),
dict(type='GenerateLocalizationLabels'),
dict(
type='PackLocalizationInputs',
keys=('gt_bbox', ),
meta_keys=('video_name', 'duration_second', 'duration_frame',
'annotations', 'feature_frame'))
]

test_pipeline = [
dict(type='LoadLocalizationFeature'),
dict(
type='PackLocalizationInputs',
keys=('gt_bbox', ),
meta_keys=('video_name', 'duration_second', 'duration_frame',
'annotations', 'feature_frame'))
]
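
# The val/test pipelines pack extra meta keys ('duration_second',
# 'duration_frame', 'feature_frame') so that proposals predicted on the
# rescaled 100-snippet axis can be mapped back to timestamps in seconds,
# while the train pipeline only needs 'gt_bbox' to compute losses.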

train_dataloader = dict(
batch_size=8,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
drop_last=True,
dataset=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=dict(video=data_root),
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=1,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=dict(video=data_root_val),
pipeline=val_pipeline,
test_mode=True))

test_dataloader = dict(
batch_size=1,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=dict(video=data_root_val),
pipeline=test_pipeline,
test_mode=True))

max_epochs = 9
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_begin=1,
val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

optim_wrapper = dict(
    optimizer=dict(type='Adam', lr=0.001, weight_decay=0.0001),
    # clip gradients to a global L2 norm of 40 to stabilize training
    clip_grad=dict(max_norm=40, norm_type=2))

param_scheduler = [
    # step schedule: decay the learning rate by 10x after epoch 7
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[7],
        gamma=0.1)
]

work_dir = './work_dirs/tcanet_2xb8-2048x100-9e_hacs-feature/'
test_evaluator = dict(
    type='ANetMetric',
    metric_type='AR@AN',  # average recall at various average numbers of proposals
    # dump the generated proposals to JSON under the work dir
    dump_config=dict(out=f'{work_dir}/results.json', output_format='json'))
val_evaluator = test_evaluator
1 change: 1 addition & 0 deletions mmaction/models/localizers/bsn.py
@@ -19,6 +19,7 @@ class TEM(BaseModel):
Code reference
https://github.com/wzmsltw/BSN-boundary-sensitive-network
Args:
temporal_dim (int): Total frames selected for each video.
tem_feat_dim (int): Feature dimension.
tem_hidden_dim (int): Hidden layer dimension.
tem_match_threshold (float): Temporal evaluation match threshold.