Commit

Merge pull request #208 from RangiLyu/refactor
[Refactor] Switch to new training code based on pytorch lightning
RangiLyu committed Apr 11, 2021
2 parents fb82fbb + 2ca0f48 commit 927143a
Showing 15 changed files with 437 additions and 226 deletions.
60 changes: 33 additions & 27 deletions README.md
@@ -20,8 +20,7 @@
* [2021.02.03] Support [EfficientNet-Lite](https://github.com/RangiLyu/EfficientNet-Lite) and [Rep-VGG](https://github.com/DingXiaoH/RepVGG) backbone. Please check the [config folder](config/). Download models in [Model Zoo](#model-zoo)

* [2021.01.10] **NanoDet-g** with lower memory access cost, which is designed for edge NPU or GPU, is now available!
Check [config/nanodet-g.yml](config/nanodet-g.yml) and download:
[COCO pre-trained model(Google Drive)](https://drive.google.com/file/d/10uW7oqZKw231l_tr4C1bJWkbCXgBf7av/view?usp=sharing) | [(BaiduDisk百度网盘)](https://pan.baidu.com/s/1IJLdtLBvmQVOmzzNY_Ci5A) code:otcd
Check [config/nanodet-g.yml](config/nanodet-g.yml) and download in [Model Zoo](#model-zoo).

<details>
<summary>More...</summary>
@@ -93,9 +92,8 @@ Inference using [Alibaba's MNN framework](https://github.com/alibaba/MNN) is in
### Pytorch demo

First, install the requirements and set up NanoDet following the installation guide. Then download the COCO pre-trained weight from here
👉[COCO pretrain weight for torch>=1.6(Google Drive)](https://drive.google.com/file/d/1EhMqGozKfqEfw8y9ftbi1jhYu86XoW62/view?usp=sharing) | [(百度网盘)](https://pan.baidu.com/s/1LCnmj2Pqhv0tsDX__1j2gg) code:6au1

👉[COCO pretrain weight for torch<=1.5(Google Drive)](https://drive.google.com/file/d/10h-0qLMCgYvWQvKULqbkLvmirFR-w9NN/view?usp=sharing) | [(百度云盘)](https://pan.baidu.com/s/1OTcPiajCcqKLg3Q0vwho3A) code:topw
👉[COCO pretrain weight (Google Drive)](https://drive.google.com/file/d/1ZkYucuLusJrCb_i63Lid0kYyyLvEiGN3/view?usp=sharing)

* Inference images

@@ -141,13 +139,13 @@ Besides, We provide a notebook [here](./demo/demo-inference-with-pytorch.ipynb)
2. Install pytorch

```shell script
conda install pytorch torchvision cudatoolkit=11.0 -c pytorch
conda install pytorch torchvision cudatoolkit=11.1 -c pytorch
```

3. Install requirements

```shell script
pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm
pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics
```

4. Setup NanoDet
@@ -166,14 +164,14 @@ NanoDet supports variety of backbones. Go to the [***config*** folder](config/)

Model | Backbone |Resolution|COCO mAP| FLOPS |Params | Pre-train weight |
:--------------------:|:------------------:|:--------:|:------:|:-----:|:-----:|:-----:|
NanoDet-m | ShuffleNetV2 1.0x | 320*320 | 20.6 | 0.72B | 0.95M | [Download](https://drive.google.com/file/d/10h-0qLMCgYvWQvKULqbkLvmirFR-w9NN/view?usp=sharing) |
NanoDet-m-416 | ShuffleNetV2 1.0x | 416*416 | 23.5 | 1.2B | 0.95M | [Download](https://drive.google.com/file/d/1h6TBy1tx4faIBKHnYeg0QwzFF6wlFBEd/view?usp=sharing)|
NanoDet-t (***NEW***) | ShuffleNetV2 1.0x | 320*320 | 21.7 | 0.96B | 1.36M | [Download](https://drive.google.com/file/d/1O2iz-aaDiQHJNfocInpFrY8ZFMrT3M1r/view?usp=sharing) |
NanoDet-g | Custom CSP Net | 416*416 | 22.9 | 4.2B | 3.81M | [Download](https://drive.google.com/file/d/10uW7oqZKw231l_tr4C1bJWkbCXgBf7av/view?usp=sharing)|
NanoDet-EfficientLite | EfficientNet-Lite0 | 320*320 | 24.7 | 1.72B | 3.11M | [Download](https://drive.google.com/file/d/1u_t9L0jqjH858gCR-vpzWzu9FexQOSmJ/view?usp=sharing)|
NanoDet-EfficientLite | EfficientNet-Lite1 | 416*416 | 30.3 | 4.06B | 4.01M | [Download](https://drive.google.com/file/d/1y9z7BToAZOQ1pKbOjNjf79YMuFuDTvfq/view?usp=sharing) |
NanoDet-EfficientLite | EfficientNet-Lite2 | 512*512 | 32.6 | 7.12B | 4.71M | [Download](https://drive.google.com/file/d/1UMXJJxRkRzgTvN1iRKeDZqGpkLxK3X4K/view?usp=sharing) |
NanoDet-RepVGG | RepVGG-A0 | 416*416 | 27.8 | 11.3B | 6.75M | [Download](https://drive.google.com/file/d/1bsT9Ksxws2O3g_IUuUwp0QwZcJlqJw3S/view?usp=sharing) |
NanoDet-m | ShuffleNetV2 1.0x | 320*320 | 20.6 | 0.72B | 0.95M | [Download](https://drive.google.com/file/d/1ZkYucuLusJrCb_i63Lid0kYyyLvEiGN3/view?usp=sharing) |
NanoDet-m-416 | ShuffleNetV2 1.0x | 416*416 | 23.5 | 1.2B | 0.95M | [Download](https://drive.google.com/file/d/1jY-Um2VDDEhuVhluP9lE70rG83eXQYhV/view?usp=sharing)|
NanoDet-t (***NEW***) | ShuffleNetV2 1.0x | 320*320 | 21.7 | 0.96B | 1.36M | [Download](https://drive.google.com/file/d/1TqRGZeOKVCb98ehTaE0gJEuND6jxwaqN/view?usp=sharing) |
NanoDet-g | Custom CSP Net | 416*416 | 22.9 | 4.2B | 3.81M | [Download](https://drive.google.com/file/d/1f2lH7Ae1AY04g20zTZY7JS_dKKP37hvE/view?usp=sharing)|
NanoDet-EfficientLite | EfficientNet-Lite0 | 320*320 | 24.7 | 1.72B | 3.11M | [Download](https://drive.google.com/file/d/1Dj1nBFc78GHDI9Wn8b3X4MTiIV2el8qP/view?usp=sharing)|
NanoDet-EfficientLite | EfficientNet-Lite1 | 416*416 | 30.3 | 4.06B | 4.01M | [Download](https://drive.google.com/file/d/1ernkb_XhnKMPdCBBtUEdwxIIBF6UVnXq/view?usp=sharing) |
NanoDet-EfficientLite | EfficientNet-Lite2 | 512*512 | 32.6 | 7.12B | 4.71M | [Download](https://drive.google.com/file/d/11V20AxXe6bTHyw3aMkgsZVzLOB31seoc/view?usp=sharing) |
NanoDet-RepVGG | RepVGG-A0 | 416*416 | 27.8 | 11.3B | 6.75M | [Download](https://drive.google.com/file/d/1nWZZ1qXb1HuIXwPSYpEyFHHqX05GaFer/view?usp=sharing) |


****
@@ -194,35 +192,43 @@ NanoDet-RepVGG | RepVGG-A0 | 416*416 | 27.8 | 11.3B | 6.75M |

Change ***num_classes*** in ***model->arch->head***.

Change image path and annotation path in both ***data->train data->val***.
Change image path and annotation path in both ***data->train*** and ***data->val***.

Set gpu, workers and batch size in ***device*** to fit your device.
Set GPU ids, number of workers and batch size in ***device*** to fit your device.

Set ***total_epochs***, ***lr*** and ***lr_schedule*** according to your dataset and batch size.

If you want to modify the network, data augmentation or other settings, please refer to [Config File Detail](docs/config_file_detail.md)

3. **Start training**

For single GPU, run
NanoDet now uses [pytorch lightning](https://github.com/PyTorchLightning/pytorch-lightning) for training.

For both single-GPU and multi-GPU training, run:

```shell script
python tools/train.py CONFIG_FILE_PATH
```

The old training script is deprecated and will be removed in the next version. If you still want to use it,

<details>
<summary>follow this...</summary>

For single GPU, run

```shell script
python tools/train.py CONFIG_PATH
python tools/deprecated/train.py CONFIG_FILE_PATH
```

For multi-GPU, NanoDet uses distributed training. (Notice: Windows does not support distributed training before pytorch 1.7.) Please run

```shell script
python -m torch.distributed.launch --nproc_per_node=GPU_NUM --master_port 29501 tools/train.py CONFIG_PATH
python -m torch.distributed.launch --nproc_per_node=GPU_NUM --master_port 29501 tools/deprecated/train.py CONFIG_FILE_PATH
```

**Experimental**:

Training with [pytorch lightning](https://github.com/PyTorchLightning/pytorch-lightning), no matter single or multi GPU just run:

```shell script
python tools/train_pl.py CONFIG_PATH
```
</details>


4. **Visualize Logs**

@@ -232,7 +238,7 @@ NanoDet-RepVGG | RepVGG-A0 | 416*416 | 27.8 | 11.3B | 6.75M |

```shell script
cd <YOUR_SAVE_DIR>
tensorboard --logdir ./logs
tensorboard --logdir ./
```

****
2 changes: 1 addition & 1 deletion docs/config_file_detail.md
@@ -15,7 +15,7 @@ Change save_dir to where you want to save logs and models. If path not exist, Na
```yaml
model:
arch:
name: xxx
name: OneStageDetector
backbone: xxx
fpn: xxx
head: xxx
3 changes: 1 addition & 2 deletions nanodet/evaluator/coco_detection.py
@@ -48,7 +48,7 @@ def results2json(self, results):
json_results.append(detection)
return json_results

def evaluate(self, results, save_dir, epoch, logger, rank=-1):
def evaluate(self, results, save_dir, rank=-1):
results_json = self.results2json(results)
json_path = os.path.join(save_dir, 'results{}.json'.format(rank))
with open(json_path, 'w') as f:
    json.dump(results_json, f)
@@ -61,5 +61,4 @@ def evaluate(self, results, save_dir, epoch, logger, rank=-1):
eval_results = {}
for k, v in zip(self.metric_names, aps):
eval_results[k] = v
logger.scalar_summary('Val_coco_bbox/' + k, 'val', v, epoch)
return eval_results
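
With `epoch` and `logger` dropped from `evaluate`, the evaluator only returns a metrics dict and the caller decides how to log it. The following is a minimal sketch of that split, not the project's code: `StubEvaluator`, `RecordingLogger` and `evaluate_and_log` are hypothetical names, and the metric values are made up for illustration.

```python
class StubEvaluator:
    """Hypothetical stand-in: returns fixed metrics instead of running COCO eval."""

    metric_names = ['mAP', 'AP_50']

    def evaluate(self, results, save_dir, rank=-1):
        aps = [0.206, 0.351]  # made-up numbers, for illustration only
        return dict(zip(self.metric_names, aps))


class RecordingLogger:
    """Hypothetical logger that records calls instead of writing to TensorBoard."""

    def __init__(self):
        self.records = []

    def scalar_summary(self, tag, phase, value, step):
        self.records.append((tag, phase, value, step))


def evaluate_and_log(evaluator, logger, results, save_dir, epoch, rank=-1):
    # The evaluator no longer needs to know about epochs or loggers.
    eval_results = evaluator.evaluate(results, save_dir, rank=rank)
    # Logging happens at the call site, under the 'Val_metrics/' prefix.
    for k, v in eval_results.items():
        logger.scalar_summary('Val_metrics/' + k, 'val', v, epoch)
    return eval_results


logger = RecordingLogger()
metrics = evaluate_and_log(StubEvaluator(), logger, results={}, save_dir='.', epoch=1)
```

Keeping the evaluator free of logging concerns lets the same `evaluate` call serve both the Lightning task and the deprecated trainer, each with its own logging backend.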
81 changes: 67 additions & 14 deletions nanodet/trainer/task.py
@@ -15,6 +15,7 @@
import copy
import os
import warnings
import json
import torch
import logging
from pytorch_lightning import LightningModule
@@ -27,25 +28,20 @@
class TrainingTask(LightningModule):
"""
Pytorch Lightning module of a general training task.
Including training, evaluating and testing.
Args:
cfg: Training configurations
evaluator: Evaluator for evaluating the model performance.
"""

def __init__(self, cfg, evaluator=None, logger=None):
"""
Args:
cfg: Training configurations
evaluator:
logger:
"""
def __init__(self, cfg, evaluator=None):
super(TrainingTask, self).__init__()
self.cfg = cfg
self.model = build_model(cfg.model)
self.evaluator = evaluator
self._logger = logger
self.save_flag = -10
self.log_style = 'NanoDet' # Log style. Choose between 'NanoDet' or 'Lightning'
# TODO: use callback to log
# TODO: remove _logger
# TODO: batch eval
# TODO: support old checkpoint

@@ -54,7 +50,7 @@ def forward(self, x):
return x

@torch.no_grad()
def predict(self, batch, batch_idx, dataloader_idx):
def predict(self, batch, batch_idx=None, dataloader_idx=None):
preds = self.forward(batch['img'])
results = self.model.head.post_process(preds, batch)
return results
@@ -103,11 +99,17 @@ def validation_step(self, batch, batch_idx):
return res

def validation_epoch_end(self, validation_step_outputs):
"""
Called at the end of the validation epoch with the outputs of all validation steps.
Evaluates the results and saves the best model.
Args:
validation_step_outputs: A list of val outputs
"""
results = {}
for res in validation_step_outputs:
results.update(res)
eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch+1,
self._logger, rank=self.local_rank)
eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
metric = eval_results[self.cfg.evaluator.save_key]
# save best model
if metric > self.save_flag:
@@ -125,9 +127,39 @@ def validation_epoch_end(self, validation_step_outputs):
warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
if self.log_style == 'Lightning':
for k, v in eval_results.items():
self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
self.log('Val_metrics/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
elif self.log_style == 'NanoDet':
for k, v in eval_results.items():
self.scalar_summary('Val_metrics/' + k, 'Val', v, self.current_epoch+1)

    def test_step(self, batch, batch_idx):
        dets = self.predict(batch, batch_idx)
        res = {batch['img_info']['id'].cpu().numpy()[0]: dets}
        return res

    def test_epoch_end(self, test_step_outputs):
        # Merge the per-step {image_id: detections} dicts into one dict.
        results = {}
        for res in test_step_outputs:
            results.update(res)
        res_json = self.evaluator.results2json(results)
        json_path = os.path.join(self.cfg.save_dir, 'results.json')
        # Use a context manager so the file handle is closed after writing.
        with open(json_path, 'w') as f:
            json.dump(res_json, f)

        if self.cfg.test_mode == 'val':
            eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, rank=self.local_rank)
            txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
            with open(txt_path, "a") as f:
                for k, v in eval_results.items():
                    f.write("{}: {}\n".format(k, v))
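
The merge-then-write pattern used by `test_epoch_end` can be tried in isolation. This is a small sketch under stated assumptions: `merge_step_outputs` and `append_eval_results` are hypothetical helper names, and the detections are placeholder strings rather than real model output.

```python
import os
import tempfile


def merge_step_outputs(step_outputs):
    """Merge a list of {image_id: detections} dicts into one dict."""
    results = {}
    for res in step_outputs:
        results.update(res)
    return results


def append_eval_results(save_dir, eval_results):
    """Append 'name: value' lines to eval_results.txt and return its path."""
    txt_path = os.path.join(save_dir, 'eval_results.txt')
    with open(txt_path, 'a') as f:
        for k, v in eval_results.items():
            f.write('{}: {}\n'.format(k, v))
    return txt_path


# Two fake step outputs; the detections are placeholders, not real model output.
step_outputs = [{1: 'dets_for_image_1'}, {2: 'dets_for_image_2'}]
merged = merge_step_outputs(step_outputs)

with tempfile.TemporaryDirectory() as save_dir:
    path = append_eval_results(save_dir, {'mAP': 0.5})
    with open(path) as f:
        written = f.read()
```

Two properties are worth keeping in mind: `dict.update` silently overwrites an earlier entry if two steps report the same image id, and the text file is opened in append mode, so repeated test runs into the same `save_dir` accumulate lines.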

def configure_optimizers(self):
"""
Prepare optimizer and learning-rate scheduler
to use in optimization.
Returns:
optimizer
"""
optimizer_cfg = copy.deepcopy(self.cfg.schedule.optimizer)
name = optimizer_cfg.pop('name')
build_optimizer = getattr(torch.optim, name)
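
The lines above look up the optimizer class by name from the config. The rest of the function is outside this hunk, so the sketch below only illustrates the general pattern and assumes the remaining config keys are passed on as keyword arguments. `FakeSGD`, `fake_optim` and `build_optimizer_from_cfg` are hypothetical stand-ins so that the example runs without PyTorch, and the config values are illustrative.

```python
import copy
from types import SimpleNamespace


class FakeSGD:
    """Hypothetical optimizer class that just records its arguments."""

    def __init__(self, params, lr, momentum=0.0):
        self.params = params
        self.lr = lr
        self.momentum = momentum


# Stand-in for the `torch.optim` namespace, so the sketch runs without PyTorch.
fake_optim = SimpleNamespace(SGD=FakeSGD)


def build_optimizer_from_cfg(optimizer_cfg, params, optim_module):
    cfg = copy.deepcopy(optimizer_cfg)   # keep the caller's config untouched
    name = cfg.pop('name')               # e.g. 'SGD'
    optimizer_cls = getattr(optim_module, name)
    return optimizer_cls(params, **cfg)  # remaining keys become keyword args


optimizer_cfg = {'name': 'SGD', 'lr': 0.14, 'momentum': 0.9}
optimizer = build_optimizer_from_cfg(optimizer_cfg, params=[], optim_module=fake_optim)
```

The deep copy matters because `pop('name')` mutates the dict; without it, building the optimizer a second time from the same config would fail on the missing key.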
@@ -153,6 +185,18 @@ def optimizer_step(self,
on_tpu=None,
using_native_amp=None,
using_lbfgs=None):
"""
Performs a single optimization step (parameter update).
Args:
epoch: Current epoch
batch_idx: Index of current batch
optimizer: A PyTorch optimizer
optimizer_idx: If you used multiple optimizers this indexes into that list.
optimizer_closure: closure for all optimizers
on_tpu: true if TPU backward is required
using_native_amp: True if using native amp
using_lbfgs: True if the matching optimizer is lbfgs
"""
# warm up lr
if self.trainer.global_step <= self.cfg.schedule.warmup.steps:
if self.cfg.schedule.warmup.name == 'constant':
@@ -180,6 +224,15 @@ def get_progress_bar_dict(self):
return items

def scalar_summary(self, tag, phase, value, step):
"""
Write Tensorboard scalar summary log.
Args:
tag: Name for the tag
phase: 'Train' or 'Val'
value: Value to record
step: Step value to record
"""
if self.local_rank < 1:
self.logger.experiment.add_scalars(tag, {phase: value}, step)

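
The `local_rank < 1` guard in `scalar_summary` keeps secondary processes from writing duplicate scalars. A minimal sketch of that behaviour follows, with a hypothetical `FakeWriter` standing in for the TensorBoard writer behind `self.logger.experiment` (its `add_scalars` mirrors the `main_tag, tag_scalar_dict, global_step` argument order of `SummaryWriter.add_scalars`).

```python
class FakeWriter:
    """Hypothetical stand-in for a TensorBoard SummaryWriter."""

    def __init__(self):
        self.calls = []

    def add_scalars(self, main_tag, tag_scalar_dict, global_step):
        self.calls.append((main_tag, tag_scalar_dict, global_step))


def scalar_summary(writer, local_rank, tag, phase, value, step):
    # local_rank < 1 lets rank 0 through, as well as negative values such as
    # the -1 default that `evaluate` uses for its `rank` argument.
    if local_rank < 1:
        writer.add_scalars(tag, {phase: value}, step)


writer = FakeWriter()
for rank in (-1, 0, 1, 2):
    scalar_summary(writer, rank, 'Val_metrics/mAP', 'Val', 0.2, step=1)
```

Only the calls made with rank `-1` and `0` reach the writer; ranks `1` and `2` are skipped.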
4 changes: 3 additions & 1 deletion nanodet/trainer/trainer.py
@@ -145,7 +145,9 @@ def run(self, train_loader, val_loader, evaluator):
results, val_loss_dict = self.run_epoch(self.epoch, val_loader, mode='val')
for k, v in val_loss_dict.items():
self.logger.scalar_summary('Epoch_loss/' + k, 'val', v, epoch)
eval_results = evaluator.evaluate(results, self.cfg.save_dir, epoch, self.logger, rank=self.rank)
eval_results = evaluator.evaluate(results, self.cfg.save_dir, rank=self.rank)
for k, v in eval_results.items():
self.logger.scalar_summary('Val_metrics/' + k, 'val', v, epoch)
if self.cfg.evaluator.save_key in eval_results:
metric = eval_results[self.cfg.evaluator.save_key]
if metric > save_flag:
2 changes: 1 addition & 1 deletion nanodet/util/__init__.py
@@ -3,7 +3,7 @@
from .logger import Logger, MovingAverage, AverageMeter
from .data_parallel import DataParallel
from .distributed_data_parallel import DDP
from .check_point import load_model_weight, save_model
from .check_point import load_model_weight, save_model, convert_old_model
from .config import cfg, load_config
from .box_transform import *
from .util_mixins import NiceRepr
29 changes: 29 additions & 0 deletions nanodet/util/check_point.py
@@ -1,11 +1,16 @@
import torch
import pytorch_lightning as pl
from collections import OrderedDict
from .rank_filter import rank_filter


def load_model_weight(model, checkpoint, logger):
    state_dict = checkpoint['state_dict']
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # strip the 'model.' prefix from the already-stripped dict, so that a
    # combined 'module.model.' prefix is handled correctly as well
    if list(state_dict.keys())[0].startswith('model.'):
        state_dict = {k[6:]: v for k, v in state_dict.items()}

    model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()

@@ -35,3 +40,27 @@ def save_model(model, path, epoch, iter, optimizer=None):
data['optimizer'] = optimizer.state_dict()

torch.save(data, path)


def convert_old_model(old_model_dict):
    if 'pytorch-lightning_version' in old_model_dict:
        raise ValueError('This model is not old format. No need to convert!')
    version = pl.__version__
    epoch = old_model_dict['epoch']
    global_step = old_model_dict['iter']
    state_dict = old_model_dict['state_dict']
    new_state_dict = OrderedDict()
    for name, value in state_dict.items():
        new_state_dict['model.' + name] = value

    new_checkpoint = {'epoch': epoch,
                      'global_step': global_step,
                      'pytorch-lightning_version': version,
                      'state_dict': new_state_dict,
                      'lr_schedulers': []}

    if 'optimizer' in old_model_dict:
        optimizer_states = [old_model_dict['optimizer']]
        new_checkpoint['optimizer_states'] = optimizer_states

    return new_checkpoint
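
The key renaming done by `convert_old_model` (adding `model.`) and the prefix stripping done by `load_model_weight` are inverse operations. The sketch below shows that round trip on plain dicts. `add_prefix` and `strip_prefix` are hypothetical helpers written for this illustration, and the string values are placeholders for tensors.

```python
from collections import OrderedDict


def add_prefix(state_dict, prefix='model.'):
    """Prefix every key, as convert_old_model does for the new format."""
    return OrderedDict((prefix + k, v) for k, v in state_dict.items())


def strip_prefix(state_dict, prefix):
    """Strip `prefix` from every key if the first key starts with it."""
    if next(iter(state_dict)).startswith(prefix):
        return OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())
    return state_dict


# Placeholder values stand in for tensors.
old_state = OrderedDict([('backbone.conv1.weight', 'w0'), ('head.cls.bias', 'b0')])
new_state = add_prefix(old_state)
# Strip 'module.' first (a no-op here), then 'model.', each on the previous result.
restored = strip_prefix(strip_prefix(new_state, 'module.'), 'model.')
```

Chaining the two `strip_prefix` calls on the intermediate result, rather than on the original dict each time, is what keeps a combined `module.model.` prefix from being mangled.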
2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
from setuptools import find_packages, setup

__version__ = "0.2.1"
__version__ = "0.3.0"

if __name__ == '__main__':
setup(