Skip to content

Commit

Permalink
Adds squeeze and excitation (scSE) modules, resolves mapbox#157
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h committed May 29, 2019
1 parent 54e20dc commit 57da3ec
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 11 deletions.
63 changes: 63 additions & 0 deletions robosat/scse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Squeeze and Excitation blocks - attention for classification and segmentation
See:
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks
- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks
"""

import torch
import torch.nn as nn


class SpatialSqChannelEx(nn.Module):
"""Spatial Squeeze and Channel Excitation (cSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 b
"""

def __init__(self, num_in, r):
super().__init__()
self.fc0 = Conv1x1(num_in, num_in // r)
self.fc1 = Conv1x1(num_in // r, num_in)

def forward(self, x):
xx = nn.functional.adaptive_avg_pool2d(x, 1)
xx = self.fc0(xx)
xx = nn.functional.relu(xx, inplace=True)
xx = self.fc1(xx)
xx = torch.sigmoid(xx)
return x * xx


class ChannelSqSpatialEx(nn.Module):
"""Channel Squeeze and Spatial Excitation (sSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 c
"""

def __init__(self, num_in):
super().__init__()
self.conv = Conv1x1(num_in, 1)

def forward(self, x):
xx = self.conv(x)
xx = torch.sigmoid(xx)
return x * xx


class SpatialChannelSqChannelEx(nn.Module):
"""Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 d
"""

def __init__(self, num_in, r=16):
super().__init__()

self.cse = SpatialSqChannelEx(num_in, r)
self.sse = ChannelSqSpatialEx(num_in)

def forward(self, x):
return self.cse(x) + self.sse(x)


def Conv1x1(num_in, num_out):
return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)
38 changes: 27 additions & 11 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from torchvision.models import resnet50

from robosat.scse import SpatialChannelSqChannelEx


class ConvRelu(nn.Module):
"""3x3 convolution followed by ReLU activation building block.
Expand Down Expand Up @@ -91,10 +93,23 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):

# Todo: make input channels configurable, not hard-coded to three channels for RGB

self.resnet = resnet50(pretrained=pretrained)

# Access resnet directly in forward pass; do not store refs here due to
# https://github.com/pytorch/pytorch/issues/8392
self.resnet = resnet50(pretrained=pretrained)

# seSE blocks to append to encoder and decoder as recommended by
# https://arxiv.org/abs/1803.02579
self.scse0 = SpatialChannelSqChannelEx(64)
self.scse1 = SpatialChannelSqChannelEx(256)
self.scse2 = SpatialChannelSqChannelEx(512)
self.scse3 = SpatialChannelSqChannelEx(1024)
self.scse4 = SpatialChannelSqChannelEx(2048)

self.scse5 = SpatialChannelSqChannelEx(num_filters * 8)
self.scse6 = SpatialChannelSqChannelEx(num_filters * 8)
self.scse7 = SpatialChannelSqChannelEx(num_filters * 2)
self.scse8 = SpatialChannelSqChannelEx(num_filters * 2 * 2)
self.scse9 = SpatialChannelSqChannelEx(num_filters)

self.center = DecoderBlock(2048, num_filters * 8)

Expand Down Expand Up @@ -122,20 +137,21 @@ def forward(self, x):
enc0 = self.resnet.conv1(x)
enc0 = self.resnet.bn1(enc0)
enc0 = self.resnet.relu(enc0)
enc0 = self.scse0(enc0)
enc0 = self.resnet.maxpool(enc0)

enc1 = self.resnet.layer1(enc0)
enc2 = self.resnet.layer2(enc1)
enc3 = self.resnet.layer3(enc2)
enc4 = self.resnet.layer4(enc3)
enc1 = self.scse1(self.resnet.layer1(enc0))
enc2 = self.scse2(self.resnet.layer2(enc1))
enc3 = self.scse3(self.resnet.layer3(enc2))
enc4 = self.scse4(self.resnet.layer4(enc3))

center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))

dec0 = self.dec0(torch.cat([enc4, center], dim=1))
dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
dec2 = self.dec2(torch.cat([enc2, dec1], dim=1))
dec3 = self.dec3(torch.cat([enc1, dec2], dim=1))
dec4 = self.dec4(dec3)
dec0 = self.scse5(self.dec0(torch.cat([enc4, center], dim=1)))
dec1 = self.scse6(self.dec1(torch.cat([enc3, dec0], dim=1)))
dec2 = self.scse7(self.dec2(torch.cat([enc2, dec1], dim=1)))
dec3 = self.scse8(self.dec3(torch.cat([enc1, dec2], dim=1)))
dec4 = self.scse9(self.dec4(dec3))
dec5 = self.dec5(dec4)

return self.final(dec5)

0 comments on commit 57da3ec

Please sign in to comment.