Merge pull request #205 from RangiLyu/refactor
[Refactor] replace lightning log with old style log
RangiLyu committed Apr 8, 2021
2 parents 8d5f011 + 1dc9405 commit fb82fbb
Showing 2 changed files with 51 additions and 14 deletions.
61 changes: 48 additions & 13 deletions nanodet/trainer/task.py
@@ -16,6 +16,7 @@
import os
import warnings
import torch
import logging
from pytorch_lightning import LightningModule
from typing import Any, List, Dict, Tuple, Optional

@@ -42,8 +43,11 @@ def __init__(self, cfg, evaluator=None, logger=None):
        self.evaluator = evaluator
        self._logger = logger
        self.save_flag = -10
        # TODO: better logger
        self.log_style = 'NanoDet'  # Log style. Choose between 'NanoDet' or 'Lightning'
        # TODO: use callback to log
        # TODO: remove _logger
        # TODO: batch eval
        # TODO: support old checkpoint

    def forward(self, x):
        x = self.model(x)
@@ -57,21 +61,43 @@ def predict(self, batch, batch_idx, dataloader_idx):

    def training_step(self, batch, batch_idx):
        preds, loss, loss_states = self.model.forward_train(batch)
        self.log('lr', self.optimizers().param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=True)
        for k, v in loss_states.items():
            self.log('Train/'+k, v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # log train losses
        if self.log_style == 'Lightning':
            self.log('lr', self.optimizers().param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=True)
            for k, v in loss_states.items():
                self.log('Train/'+k, v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        elif self.log_style == 'NanoDet' and self.global_step % self.cfg.log.interval == 0:
            lr = self.optimizers().param_groups[0]['lr']
            log_msg = 'Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(self.current_epoch+1,
                self.cfg.schedule.total_epochs, self.global_step, batch_idx, lr)
            self.scalar_summary('Train_loss/lr', 'Train', lr, self.global_step)
            for l in loss_states:
                log_msg += '{}:{:.4f}| '.format(l, loss_states[l].mean().item())
                self.scalar_summary('Train_loss/' + l, 'Train', loss_states[l].mean().item(), self.global_step)
            self.info(log_msg)

        return loss

    def training_epoch_end(self, outputs: List[Any]) -> None:
        self.print('Epoch ', self.current_epoch, ' finished.')
        self.trainer.save_checkpoint(os.path.join(self.cfg.save_dir, 'model_last.ckpt'))
        self.lr_scheduler.step()

    def validation_step(self, batch, batch_idx):
        preds, loss, loss_states = self.model.forward_train(batch)
        self.log('Val/loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
        for k, v in loss_states.items():
            self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

        if self.log_style == 'Lightning':
            self.log('Val/loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=False)
            for k, v in loss_states.items():
                self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        elif self.log_style == 'NanoDet' and batch_idx % self.cfg.log.interval == 0:
            lr = self.optimizers().param_groups[0]['lr']
            log_msg = 'Val|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(self.current_epoch+1,
                self.cfg.schedule.total_epochs, self.global_step, batch_idx, lr)
            for l in loss_states:
                log_msg += '{}:{:.4f}| '.format(l, loss_states[l].mean().item())
            self.info(log_msg)

        dets = self.model.head.post_process(preds, batch)
        res = {batch['img_info']['id'].cpu().numpy()[0]: dets}
        return res
@@ -80,7 +106,7 @@ def validation_epoch_end(self, validation_step_outputs):
        results = {}
        for res in validation_step_outputs:
            results.update(res)
        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch,
        eval_results = self.evaluator.evaluate(results, self.cfg.save_dir, self.current_epoch+1,
                                               self._logger, rank=self.local_rank)
        metric = eval_results[self.cfg.evaluator.save_key]
        # save best model
@@ -97,8 +123,9 @@ def validation_epoch_end(self, validation_step_outputs):
                        f.write("{}: {}\n".format(k, v))
        else:
            warnings.warn('Warning! Save_key is not in eval results! Only save model last!')
        for k, v in eval_results.items():
            self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        if self.log_style == 'Lightning':
            for k, v in eval_results.items():
                self.log('Val/' + k, v, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        optimizer_cfg = copy.deepcopy(self.cfg.schedule.optimizer)
@@ -140,8 +167,6 @@ def optimizer_step(self,
                raise Exception('Unsupported warm up type!')
            for pg in optimizer.param_groups:
                pg['lr'] = warmup_lr
            # TODO: log lr to tensorboard
            # self.log('lr', optimizer.param_groups[0]['lr'], on_step=True, on_epoch=True, prog_bar=True)

        # update params
        optimizer.step(closure=optimizer_closure)
@@ -151,8 +176,18 @@ def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        items.pop("loss", None)
        return items

    def scalar_summary(self, tag, phase, value, step):
        if self.local_rank < 1:
            self.logger.experiment.add_scalars(tag, {phase: value}, step)

    def info(self, string):
        if self.local_rank < 1:
            logging.info(string)





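Aside (not part of the commit): the NanoDet-style branch above combines TensorBoard's add_scalars grouping (via the new scalar_summary helper) with Python's standard logging module (via info). A self-contained sketch of that pattern is below; the tag names, step, learning rate, loss values, and the 'workspace/demo' directory are made up for illustration only.

import logging
from torch.utils.tensorboard import SummaryWriter

logging.basicConfig(level=logging.INFO)
writer = SummaryWriter('workspace/demo')  # hypothetical save dir; in the task, self.logger.experiment plays this role

global_step, lr = 100, 1e-3
loss_states = {'loss_qfl': 0.42, 'loss_bbox': 0.31}  # example loss values

log_msg = 'Train|Epoch{}/{}|Iter{}({})| lr:{:.2e}| '.format(1, 280, global_step, 0, lr)
writer.add_scalars('Train_loss/lr', {'Train': lr}, global_step)  # grouped scalar, as in scalar_summary()
for name, value in loss_states.items():
    log_msg += '{}:{:.4f}| '.format(name, value)
    writer.add_scalars('Train_loss/' + name, {'Train': value}, global_step)
logging.info(log_msg)  # plain-text console line, as in info()
writer.close()

In the task itself, self.logger.experiment is expected to be the SummaryWriter behind Lightning's TensorBoard logger, so scalar_summary writes into the same run as the rest of training.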
4 changes: 3 additions & 1 deletion tools/train_pl.py
@@ -17,6 +17,7 @@
import argparse
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ProgressBar

from nanodet.util import mkdir, Logger, cfg, load_config
from nanodet.data.collate import collate_function
@@ -91,7 +92,8 @@ def main(args):
                         accelerator='ddp',
                         log_every_n_steps=cfg.log.interval,
                         num_sanity_val_steps=0,
                         resume_from_checkpoint=model_resume_path
                         resume_from_checkpoint=model_resume_path,
                         callbacks=[ProgressBar(refresh_rate=0)]  # disable tqdm bar
                         )

    trainer.fit(task, train_dataloader, val_dataloader)
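Aside (not part of the commit): the ProgressBar(refresh_rate=0) callback added above relies on PyTorch Lightning treating a refresh rate of 0 as a disabled progress bar, so the NanoDet-style text log becomes the only console output. A minimal sketch with placeholder trainer settings (in train_pl.py the real values come from the yaml config):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ProgressBar

# Placeholder settings, not taken from cfg.
trainer = pl.Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,
    callbacks=[ProgressBar(refresh_rate=0)],  # refresh_rate=0 -> Lightning shows no tqdm bar
)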
