[9f71e2]: Section 3 Simulate DIMSE/src/networks/RecursiveUNet.py


# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defines the UNet.
# |num_downs|: number of downsamplings in the UNet. For example,
# if |num_downs| == 7, a 128x128 image is reduced to 1x1 at the bottleneck.

# recursive implementation of UNet
import torch
from torch import nn

class UNet(nn.Module):
    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4,
                 norm_layer=nn.InstanceNorm2d):
        # norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UNet, self).__init__()

        # construct unet structure: start with the innermost (bottleneck) block,
        # then wrap it in progressively shallower blocks up to the outermost one
        unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-1),
                                             out_channels=initial_filter_size * 2 ** num_downs,
                                             num_classes=num_classes, kernel_size=kernel_size,
                                             norm_layer=norm_layer, innermost=True)
        for i in range(1, num_downs):
            unet_block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs-(i+1)),
                                                 out_channels=initial_filter_size * 2 ** (num_downs-i),
                                                 num_classes=num_classes, kernel_size=kernel_size,
                                                 submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             num_classes=num_classes, kernel_size=kernel_size,
                                             submodule=unet_block, norm_layer=norm_layer, outermost=True)

        self.model = unet_block

    def forward(self, x):
        return self.model(x)
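
# Illustrative walk-through (added note, not part of the original file): with the
# defaults above (in_channels=1, initial_filter_size=64, num_downs=4), the
# constructor nests blocks from the bottleneck outward, so the encoder channel
# counts run 1 -> 64 -> 128 -> 256 -> 512 -> 1024 at the bottleneck:
#   innermost:   in=512, out=1024
#   loop i=1..3: (in=256, out=512), (in=128, out=256), (in=64, out=128)
#   outermost:   in=1,   out=64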

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d,
                 use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool2d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                              norm_layer=norm_layer)

        # upconv
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            # outermost block: no pooling; a 1x1 conv maps features to class scores
            final = nn.Conv2d(out_channels, num_classes, kernel_size=1)
            down = [conv1, conv2]
            up = [conv3, conv4, final]
            model = down + [submodule] + up
        elif innermost:
            # innermost (bottleneck) block: no submodule, upsample straight away
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)
            model = [pool, conv1, conv2, upconv]
        else:
            # intermediate block: pool, recurse into the submodule, then upsample
            upconv = nn.ConvTranspose2d(in_channels*2, in_channels, kernel_size=2, stride=2)
            down = [pool, conv1, conv2]
            up = [conv3, conv4, upconv]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)
    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
        # conv -> norm -> LeakyReLU; note padding=1 preserves spatial size only for kernel_size=3
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer
    @staticmethod
    def center_crop(layer, target_width, target_height):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # crop the upsampled features back to the input's spatial size,
            # then concatenate along channels to form the skip connection
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])
            return torch.cat([x, crop], 1)
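
# --- Added sketch (not part of the original file): a minimal smoke test.
# It assumes input height/width divisible by 2 ** num_downs, so the pooling and
# stride-2 transposed convolutions round-trip cleanly to the input resolution.
if __name__ == "__main__":
    net = UNet(num_classes=3, in_channels=1, initial_filter_size=64, num_downs=4)
    x = torch.randn(1, 1, 256, 256)  # (batch, channels, height, width)
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 3, 256, 256])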