Skip to content

Commit

Permalink
Change API to pytorch 0.4; Also, support cpu during demo and test tim…
Browse files Browse the repository at this point in the history
…e (Doesn't make sense to support cpu during training.).
  • Loading branch information
ruotianluo committed Apr 25, 2018
1 parent fa88df8 commit 7fd5263
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 140 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Additional features not mentioned in the [report](https://arxiv.org/pdf/1702.021
- **Support for visualization**. The current implementation will summarize ground truth boxes, statistics of losses, activations and variables during training, and dump it to a separate folder for tensorboard visualization. The computing graph is also saved for debugging.

### Prerequisites
- A basic pytorch installation. The code follows **0.3**. If you are using old **0.1.12** or **0.2**, you can checkout the corresponding branch.
- A basic pytorch installation. The code follows **0.4**. If you are using the old **0.1.12**, **0.2**, or **0.3**, you can check out the corresponding branch.
- Python packages you might not have: `cffi`, `opencv-python`, `easydict` (similar to [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn)). For `easydict` make sure you have the right version. Xinlei uses 1.6.
- [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) to visualize the training and validation curve. Please build from source to use the latest tensorflow-tensorboard.
- ~~Docker users: Since the recent upgrade, the docker image on docker hub (https://hub.docker.com/r/mbuckler/tf-faster-rcnn-deps/) is no longer valid. However, you can still build your own image by using the dockerfile located in the `docker` folder (cuda 8 version, as it is required by Tensorflow r1.0). Also, make sure you follow the Tensorflow installation instructions to install and use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker). Last, after launching the container, you have to build the Cython modules within the running container.~~
Expand Down
4 changes: 1 addition & 3 deletions lib/layer_utils/proposal_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from model.nms_wrapper import nms

import torch
from torch.autograd import Variable


def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride, anchors, num_anchors):
"""A simplified version compared to fast/er RCNN
Expand Down Expand Up @@ -50,7 +48,7 @@ def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride,
scores = scores[keep,]

# Only support single image as input
batch_inds = Variable(proposals.data.new(proposals.size(0), 1).zero_())
batch_inds = proposals.new_zeros(proposals.size(0), 1)
blob = torch.cat((batch_inds, proposals), 1)

return blob, scores
17 changes: 8 additions & 9 deletions lib/layer_utils/proposal_target_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import torch
from torch.autograd import Variable

def proposal_target_layer(rpn_rois, rpn_scores, gt_boxes, _num_classes):
"""
Expand All @@ -31,7 +30,7 @@ def proposal_target_layer(rpn_rois, rpn_scores, gt_boxes, _num_classes):

# Include ground-truth boxes in the set of candidate rois
if cfg.TRAIN.USE_GT:
zeros = rpn_rois.data.new(gt_boxes.shape[0], 1)
zeros = rpn_rois.new_zeros(gt_boxes.shape[0], 1)
all_rois = torch.cat(
(all_rois, torch.cat((zeros, gt_boxes[:, :-1]), 1))
, 0)
Expand All @@ -55,7 +54,7 @@ def proposal_target_layer(rpn_rois, rpn_scores, gt_boxes, _num_classes):
bbox_inside_weights = bbox_inside_weights.view(-1, _num_classes * 4)
bbox_outside_weights = (bbox_inside_weights > 0).float()

return rois, roi_scores, labels, Variable(bbox_targets), Variable(bbox_inside_weights), Variable(bbox_outside_weights)
return rois, roi_scores, labels, bbox_targets, bbox_inside_weights, bbox_outside_weights


def _get_bbox_regression_labels(bbox_target_data, num_classes):
Expand All @@ -72,8 +71,8 @@ def _get_bbox_regression_labels(bbox_target_data, num_classes):
# Inputs are tensor

clss = bbox_target_data[:, 0]
bbox_targets = clss.new(clss.numel(), 4 * num_classes).zero_()
bbox_inside_weights = clss.new(bbox_targets.shape).zero_()
bbox_targets = clss.new_zeros(clss.numel(), 4 * num_classes)
bbox_inside_weights = clss.new_zeros(bbox_targets.shape)
inds = (clss > 0).nonzero().view(-1)
if inds.numel() > 0:
clss = clss[inds].contiguous().view(-1,1)
Expand Down Expand Up @@ -122,17 +121,17 @@ def _sample_rois(all_rois, all_scores, gt_boxes, fg_rois_per_image, rois_per_ima
# Small modification to the original version where we ensure a fixed number of regions are sampled
if fg_inds.numel() > 0 and bg_inds.numel() > 0:
fg_rois_per_image = min(fg_rois_per_image, fg_inds.numel())
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(fg_rois_per_image), replace=False)).long().cuda()]
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(fg_rois_per_image), replace=False)).long().to(gt_boxes.device)]
bg_rois_per_image = rois_per_image - fg_rois_per_image
to_replace = bg_inds.numel() < bg_rois_per_image
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(bg_rois_per_image), replace=to_replace)).long().cuda()]
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(bg_rois_per_image), replace=to_replace)).long().to(gt_boxes.device)]
elif fg_inds.numel() > 0:
to_replace = fg_inds.numel() < rois_per_image
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().cuda()]
fg_inds = fg_inds[torch.from_numpy(npr.choice(np.arange(0, fg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().to(gt_boxes.device)]
fg_rois_per_image = rois_per_image
elif bg_inds.numel() > 0:
to_replace = bg_inds.numel() < rois_per_image
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().cuda()]
bg_inds = bg_inds[torch.from_numpy(npr.choice(np.arange(0, bg_inds.numel()), size=int(rois_per_image), replace=to_replace)).long().to(gt_boxes.device)]
fg_rois_per_image = 0
else:
import pdb
Expand Down
4 changes: 2 additions & 2 deletions lib/layer_utils/proposal_top_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, ancho
if length < rpn_top_n:
# Random selection, maybe unnecessary and loses good proposals
# But such case rarely happens
top_inds = torch.from_numpy(npr.choice(length, size=rpn_top_n, replace=True)).long().cuda()
top_inds = torch.from_numpy(npr.choice(length, size=rpn_top_n, replace=True)).long().to(anchors.device)
else:
top_inds = scores.sort(0, descending=True)[1]
top_inds = top_inds[:rpn_top_n]
Expand All @@ -50,6 +50,6 @@ def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, ancho
# Output rois blob
# Our RPN implementation only supports a single input image, so all
# batch inds are 0
batch_inds = proposals.data.new(proposals.size(0), 1).zero_()
batch_inds = proposals.new_zeros(proposals.size(0), 1)
blob = torch.cat([batch_inds, proposals], 1)
return blob, scores
48 changes: 0 additions & 48 deletions lib/layer_utils/roi_pooling/roi_pool_py.py

This file was deleted.

2 changes: 1 addition & 1 deletion lib/model/train_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def train_model(self, max_iters):
next_stepsize = stepsizes.pop()

self.net.train()
self.net.cuda()
self.net.to(self.net._device)

while iter < max_iters + 1:
# Learning rate
Expand Down
46 changes: 24 additions & 22 deletions lib/nets/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self):
self._event_summaries = {}
self._image_gt_summaries = {}
self._variables_to_fix = {}
self._device = 'cuda'

def _add_gt_image(self):
# add back mean
Expand Down Expand Up @@ -125,10 +126,10 @@ def _anchor_target_layer(self, rpn_cls_score):
anchor_target_layer(
rpn_cls_score.data, self._gt_boxes.data.cpu().numpy(), self._im_info, self._feat_stride, self._anchors.data.cpu().numpy(), self._num_anchors)

rpn_labels = Variable(torch.from_numpy(rpn_labels).float().cuda()) #.set_shape([1, 1, None, None])
rpn_bbox_targets = Variable(torch.from_numpy(rpn_bbox_targets).float().cuda())#.set_shape([1, None, None, self._num_anchors * 4])
rpn_bbox_inside_weights = Variable(torch.from_numpy(rpn_bbox_inside_weights).float().cuda())#.set_shape([1, None, None, self._num_anchors * 4])
rpn_bbox_outside_weights = Variable(torch.from_numpy(rpn_bbox_outside_weights).float().cuda())#.set_shape([1, None, None, self._num_anchors * 4])
rpn_labels = torch.from_numpy(rpn_labels).float().to(self._device) #.set_shape([1, 1, None, None])
rpn_bbox_targets = torch.from_numpy(rpn_bbox_targets).float().to(self._device)#.set_shape([1, None, None, self._num_anchors * 4])
rpn_bbox_inside_weights = torch.from_numpy(rpn_bbox_inside_weights).float().to(self._device)#.set_shape([1, None, None, self._num_anchors * 4])
rpn_bbox_outside_weights = torch.from_numpy(rpn_bbox_outside_weights).float().to(self._device)#.set_shape([1, None, None, self._num_anchors * 4])

rpn_labels = rpn_labels.long()
self._anchor_targets['rpn_labels'] = rpn_labels
Expand Down Expand Up @@ -164,7 +165,7 @@ def _anchor_component(self, height, width):
anchors, anchor_length = generate_anchors_pre(\
height, width,
self._feat_stride, self._anchor_scales, self._anchor_ratios)
self._anchors = Variable(torch.from_numpy(anchors).cuda())
self._anchors = torch.from_numpy(anchors).to(self._device)
self._anchor_length = anchor_length

def _smooth_l1_loss(self, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights, sigma=1.0, dim=[1]):
Expand All @@ -186,7 +187,7 @@ def _add_losses(self, sigma_rpn=3.0):
# RPN, class loss
rpn_cls_score = self._predictions['rpn_cls_score_reshape'].view(-1, 2)
rpn_label = self._anchor_targets['rpn_labels'].view(-1)
rpn_select = Variable((rpn_label.data != -1).nonzero().view(-1))
rpn_select = (rpn_label.data != -1).nonzero().view(-1)
rpn_cls_score = rpn_cls_score.index_select(0, rpn_select).contiguous().view(-1, 2)
rpn_label = rpn_label.index_select(0, rpn_select).contiguous().view(-1)
rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)
Expand Down Expand Up @@ -325,7 +326,7 @@ def _run_summary_op(self, val=False):
summaries.append(self._add_gt_image_summary())
# Add event_summaries
for key, var in self._event_summaries.items():
summaries.append(tb.summary.scalar(key, var.data[0]))
summaries.append(tb.summary.scalar(key, var.item()))
self._event_summaries = {}
if not val:
# Add score summaries
Expand Down Expand Up @@ -375,9 +376,9 @@ def forward(self, image, im_info, gt_boxes=None, mode='TRAIN'):
self._image_gt_summaries['gt_boxes'] = gt_boxes
self._image_gt_summaries['im_info'] = im_info

self._image = Variable(torch.from_numpy(image.transpose([0,3,1,2])).cuda(), volatile=mode == 'TEST')
self._image = torch.from_numpy(image.transpose([0,3,1,2])).to(self._device)
self._im_info = im_info # No need to change; actually it can be an list
self._gt_boxes = Variable(torch.from_numpy(gt_boxes).cuda()) if gt_boxes is not None else None
self._gt_boxes = torch.from_numpy(gt_boxes).to(self._device) if gt_boxes is not None else None

self._mode = mode

Expand All @@ -386,7 +387,7 @@ def forward(self, image, im_info, gt_boxes=None, mode='TRAIN'):
if mode == 'TEST':
stds = bbox_pred.data.new(cfg.TRAIN.BBOX_NORMALIZE_STDS).repeat(self._num_classes).unsqueeze(0).expand_as(bbox_pred)
means = bbox_pred.data.new(cfg.TRAIN.BBOX_NORMALIZE_MEANS).repeat(self._num_classes).unsqueeze(0).expand_as(bbox_pred)
self._predictions["bbox_pred"] = bbox_pred.mul(Variable(stds)).add(Variable(means))
self._predictions["bbox_pred"] = bbox_pred.mul(stds).add(means)
else:
self._add_losses() # compute losses

Expand All @@ -411,13 +412,14 @@ def normal_init(m, mean, stddev, truncated=False):
# Extract the head feature maps, for example for vgg16 it is conv5_3
# only useful during testing mode
def extract_head(self, image):
feat = self._layers["head"](Variable(torch.from_numpy(image.transpose([0,3,1,2])).cuda(), volatile=True))
feat = self._layers["head"](torch.from_numpy(image.transpose([0,3,1,2])).to(self._device))
return feat

# only useful during testing mode
def test_image(self, image, im_info):
self.eval()
self.forward(image, im_info, None, mode='TEST')
with torch.no_grad():
self.forward(image, im_info, None, mode='TEST')
cls_score, cls_prob, bbox_pred, rois = self._predictions["cls_score"].data.cpu().numpy(), \
self._predictions['cls_prob'].data.cpu().numpy(), \
self._predictions['bbox_pred'].data.cpu().numpy(), \
Expand All @@ -440,11 +442,11 @@ def get_summary(self, blobs):

def train_step(self, blobs, train_op):
self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].data[0], \
self._losses['rpn_loss_box'].data[0], \
self._losses['cross_entropy'].data[0], \
self._losses['loss_box'].data[0], \
self._losses['total_loss'].data[0]
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].item(), \
self._losses['rpn_loss_box'].item(), \
self._losses['cross_entropy'].item(), \
self._losses['loss_box'].item(), \
self._losses['total_loss'].item()
#utils.timer.timer.tic('backward')
train_op.zero_grad()
self._losses['total_loss'].backward()
Expand All @@ -457,11 +459,11 @@ def train_step(self, blobs, train_op):

def train_step_with_summary(self, blobs, train_op):
self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].data[0], \
self._losses['rpn_loss_box'].data[0], \
self._losses['cross_entropy'].data[0], \
self._losses['loss_box'].data[0], \
self._losses['total_loss'].data[0]
rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss = self._losses["rpn_cross_entropy"].item(), \
self._losses['rpn_loss_box'].item(), \
self._losses['cross_entropy'].item(), \
self._losses['loss_box'].item(), \
self._losses['total_loss'].item()
train_op.zero_grad()
self._losses['total_loss'].backward()
train_op.step()
Expand Down
6 changes: 4 additions & 2 deletions lib/utils/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def __init__(self):
def tic(self, name='default'):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
self._start_time[name] = time.time()

def toc(self, name='default', average=True):
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
self._diff[name] = time.time() - self._start_time[name]
self._total_time[name] = self._total_time.get(name, 0.) + self._diff[name]
self._calls[name] = self._calls.get(name, 0 ) + 1
Expand Down
103 changes: 55 additions & 48 deletions tools/demo.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ def parse_args():
net.create_architecture(21,
tag='default', anchor_scales=[8, 16, 32])

net.load_state_dict(torch.load(saved_model))
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))

net.eval()
net.cuda()
if not torch.cuda.is_available():
net._device = 'cpu'
net.to(net._device)

print('Loaded network {:s}'.format(saved_model))

Expand Down
6 changes: 4 additions & 2 deletions tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ def parse_args():
anchor_ratios=cfg.ANCHOR_RATIOS)

net.eval()
net.cuda()
if not torch.cuda.is_available():
net._device = 'cpu'
net.to(net._device)

if args.model:
print(('Loading model check point from {:s}').format(args.model))
net.load_state_dict(torch.load(args.model))
net.load_state_dict(torch.load(args.model, map_location=lambda storage, loc: storage))
print('Loaded.')
else:
print(('Loading initial weights from {:s}').format(args.weight))
Expand Down

0 comments on commit 7fd5263

Please sign in to comment.