# SPDX-License-Identifier: Apache-2.0
#
# Copyright (C) 2021 Supervisely
#
# This file is part of the Supervisely project and has been taken
# from the Supervisely repository (https://github.com/supervisely/supervisely/blob/master/plugins/nn/unet_v2/src/unet.py).
# It is being redistributed under the Apache License 2.0.
#
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torchvision.models.vgg import vgg16_bn
class ConvBNAct(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    The convolution uses ``kernel_size=3`` with ``padding=1`` (default stride),
    so the spatial size of the input is preserved and only the channel count
    changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the conv -> BN -> ReLU stack.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by the convolution.
        """
        super().__init__()
        # The attribute name ``seq`` is part of the state_dict key layout
        # ("seq.0.weight", ...), so it must not be renamed.
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch norm and ReLU to ``inputs``.

        Args:
            inputs: Feature map with ``in_channels`` channels.

        Returns:
            Non-negative feature map with ``out_channels`` channels and the
            same spatial size as ``inputs``.
        """
        return self.seq(inputs)
class Block(nn.Module):
    """Three stacked ``ConvBNAct`` layers.

    The first layer maps ``src_channels`` to ``dst_channels``; the remaining
    two keep ``dst_channels``.
    """

    def __init__(self, src_channels: int, dst_channels: int) -> None:
        """Create the three conv/BN/ReLU stages.

        Args:
            src_channels: Number of channels of the incoming feature map.
            dst_channels: Number of channels produced by every stage.
        """
        super().__init__()
        # Attribute names are kept as-is: they define the state_dict keys.
        self.seq1 = ConvBNAct(src_channels, dst_channels)
        self.seq2 = ConvBNAct(dst_channels, dst_channels)
        self.seq3 = ConvBNAct(dst_channels, dst_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the three stages in order.

        Args:
            x: Feature map with ``src_channels`` channels.

        Returns:
            Feature map with ``dst_channels`` channels.
        """
        return self.seq3(self.seq2(self.seq1(x)))
class UNetUp(nn.Module):
    """Decoder up-step: upsample the deeper features and join them with a skip.

    The deeper ("bottom") feature map is upsampled by a factor of 2 with
    nearest-neighbour interpolation, projected to ``right_channels`` with a
    1x1 convolution and concatenated, along the channel axis, after the
    skip-connection ("left") feature map.
    """

    def __init__(self, down_channels: int, right_channels: int) -> None:
        """Create the upsampling and projection layers.

        Args:
            down_channels: Number of channels of the deeper feature map.
            right_channels: Number of channels the deeper feature map is
                projected to before concatenation.
        """
        super().__init__()
        self.bottom_up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(down_channels, right_channels, kernel_size=1, stride=1)

    def forward(self, left: torch.Tensor, bottom: torch.Tensor) -> torch.Tensor:
        """Merge a skip connection with the upsampled deeper features.

        Args:
            left: Skip-connection feature map. Its spatial size must be twice
                that of ``bottom`` for the concatenation to succeed.
            bottom: Deeper feature map with ``down_channels`` channels.

        Returns:
            Concatenation ``[left, projected_bottom]`` along dimension 1, i.e.
            ``left``'s channels followed by ``right_channels`` channels.
        """
        upsampled = self.conv(self.bottom_up(bottom))
        return torch.cat([left, upsampled], 1)
class Bottleneck(nn.Module):
    """Two dilated conv/BN/ReLU stages whose result is appended to the input.

    Both convolutions use a 3x3 kernel with ``dilation=2`` and ``padding=2``,
    which keeps the spatial size unchanged. The output is the input
    concatenated with the processed features, so it carries
    ``in_channels + out_channels`` channels.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Create the two dilated convolution stages.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels produced by each convolution.
        """
        super().__init__()
        # Attribute names and creation order are kept so that state_dict keys
        # and parameter ordering stay compatible with existing checkpoints.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=(3, 3), dilation=2, padding=2
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=(3, 3), dilation=2, padding=2
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process ``x`` and concatenate the result after it.

        Args:
            x: Feature map with ``in_channels`` channels; concatenation happens
                along dimension 1, so a batched (N, C, H, W) layout is expected.

        Returns:
            ``torch.cat((x, features), dim=1)`` where ``features`` is the
            non-negative output of the second conv/BN/ReLU stage.
        """
        first = self.relu(self.bn(self.conv(x)))
        second = self.relu2(self.bn2(self.conv2(first)))
        return torch.cat((x, second), dim=1)
class UNet(nn.Module):
    """U-Net style network assembled around externally supplied encoder blocks.

    The encoder blocks are applied one after another and every intermediate
    output is kept as a skip connection. The deepest output goes through a
    ``Bottleneck`` and a first decoder ``Block``; afterwards each decoder level
    upsamples (``UNetUp``), concatenates the matching skip connection and
    refines the result with another ``Block``. A final 1x1 convolution maps to
    ``n_cls`` channels.
    """

    def __init__(self, encoder_blocks, encoder_channels, n_cls):
        """Build the decoder for the given encoder.

        Args:
            encoder_blocks: Sequence of modules forming the encoder, ordered
                from the shallowest to the deepest level.
            encoder_channels: Channel count per level, same length as
                ``encoder_blocks``. Entries ``[:-1]`` are the channels of the
                skip connections; the last entry is the channel count fed to
                the first decoder block, i.e. the bottleneck output.
            n_cls: Number of output channels of the final 1x1 convolution.

        Raises:
            ValueError: If ``encoder_blocks`` and ``encoder_channels`` differ
                in length.
        """
        super().__init__()
        self.encoder_channels = encoder_channels
        self.depth = len(self.encoder_channels)
        if len(encoder_blocks) != self.depth:
            raise ValueError(
                f"Expected {self.depth} encoder blocks to match encoder_channels, "
                f"got {len(encoder_blocks)}"
            )

        # Registration order (encoder_blocks, blocks, ups, last_conv, bottle) is
        # kept unchanged so that parameter ordering stays checkpoint-compatible.
        self.encoder_blocks = nn.ModuleList(encoder_blocks)

        self.blocks = nn.ModuleList()
        # First decoder block: refines the bottleneck output at the deepest level.
        self.blocks.append(Block(self.encoder_channels[-1], self.encoder_channels[-1]))

        self.ups = nn.ModuleList()
        for i in range(1, self.depth):
            bottom_channels = self.encoder_channels[self.depth - i]
            left_channels = self.encoder_channels[self.depth - i - 1]
            # The upsampled branch is projected to the skip connection's width.
            right_channels = left_channels
            self.ups.append(UNetUp(bottom_channels, right_channels))
            self.blocks.append(Block(left_channels + right_channels, right_channels))

        self.last_conv = nn.Conv2d(encoder_channels[0], n_cls, 1)
        # NOTE: hard-coded. The bottleneck expects a 512-channel encoder output
        # and returns 512 + 512 = 1024 channels, which therefore has to equal
        # ``encoder_channels[-1]`` (the input width of ``self.blocks[0]``).
        self.bottle = Bottleneck(512, 512)

    def forward(self, x):
        """Compute per-class score maps for ``x``.

        Args:
            x: Input tensor accepted by the first encoder block. Each decoder
                level upsamples by 2, so every encoder block after the first is
                expected to halve the spatial size of its input.

        Returns:
            Raw scores with ``n_cls`` channels; no softmax or log-softmax is
            applied.
        """
        encoder_outputs = []
        for encoder_block in self.encoder_blocks:
            x = encoder_block(x)
            encoder_outputs.append(x)

        # Deepest level: bottleneck followed by the first decoder block.
        x = self.bottle(encoder_outputs[-1])
        x = self.blocks[0](x)

        # Walk back up, pairing each up-step with the skip connection of the
        # next shallower level (deepest skip first).
        skips = reversed(encoder_outputs[:-1])
        for up, block, skip in zip(self.ups, self.blocks[1:], skips):
            x = block(up(skip, x))

        return self.last_conv(x)  # no softmax or log_softmax
def _get_encoder_blocks(model):
    """Split ``model.features`` into consecutive ``nn.Sequential`` blocks.

    The layers of the child named ``"features"`` are walked in order and
    collected into a block; a block is closed after the layer whose name is
    one of the boundary names below (the last modules, ReLUs, of the VGG
    blocks). Layers are re-registered in the new blocks under their original
    names.

    Args:
        model: Module that exposes a child named ``"features"``.

    Returns:
        List of ``nn.Sequential`` blocks, one per boundary name encountered.
        Layers after the last boundary are not included. An empty list is
        returned when ``model`` has no ``"features"`` child.
    """
    # Names of the last module (ReLU) of each VGG block inside ``features``.
    block_end_names = {"5", "12", "22", "32", "42"}

    blocks = []
    for child_name, child in model.named_children():
        if child_name != "features":
            continue

        current = nn.Sequential()
        for layer_name, layer in child.named_children():
            current.add_module(layer_name, layer)
            if layer_name in block_end_names:
                blocks.append(current)
                current = nn.Sequential()
        # Whatever is left in ``current`` (layers past the last boundary) is
        # intentionally dropped; only the first "features" child is processed.
        break
    return blocks
def construct_unet(n_cls, pretrain=False):
    """Build a ``UNet`` whose encoder is taken from torchvision's VGG16-BN.

    Only the encoder comes with weights from ``vgg16_bn``; the decoder layers
    created by ``UNet`` keep their default initialisation.

    Args:
        n_cls: Number of output channels (classes) of the network.
        pretrain: Currently unused, see the note below.

    Returns:
        The assembled ``UNet`` instance.
    """
    # NOTE(review): ``pretrain`` is never read; the backbone is always created
    # with ``weights="DEFAULT"`` regardless of its value. Behaviour is kept
    # as-is for backward compatibility - confirm whether ``pretrain=False``
    # was meant to skip loading the pretrained weights.
    backbone = vgg16_bn(weights="DEFAULT")
    encoder_blocks = _get_encoder_blocks(backbone)
    # Channels per level as expected by ``UNet``: the first four entries are
    # the skip-connection widths; the final 1024 is the width after UNet's
    # Bottleneck(512, 512), which concatenates its 512-channel input with 512
    # new channels.
    encoder_channels = [64, 128, 256, 512, 1024]
    return UNet(encoder_blocks, encoder_channels, n_cls)