# [f2cb69]: CellGraph / layers_custom.py

import torch
import torch.nn as nn


def down_shift(x, pad=None):
    """Shift an NCHW tensor down by one row, zero-padding the top."""
    # PyTorch ordering: (batch, channels, height, width)
    xs = [int(y) for y in x.size()]
    # When down-shifting, the last row is removed ...
    x = x[:, :, :xs[2] - 1, :]
    # ... and one zero row is padded on top: (left, right, top, bottom)
    pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad
    return pad(x)
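
# Illustrative doctest-style sketch (not part of the original file): row r of
# the output holds row r - 1 of the input, with a zero row on top.
#
#   >>> t = torch.arange(9.).view(1, 1, 3, 3)
#   >>> down_shift(t)
#   tensor([[[[0., 0., 0.],
#             [0., 1., 2.],
#             [3., 4., 5.]]]])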


class MaskedConv2d(nn.Conv2d):
    def __init__(self, c_in, c_out, k_size, stride, pad, use_down_shift=False):
        super(MaskedConv2d, self).__init__(
            c_in, c_out, k_size, stride, pad, bias=False)

        ch_out, ch_in, height, width = self.weight.size()

        # Mask: rows at and above the centre are kept, rows below are zeroed.
        # -------------------------------------
        # | 1   1   1     1       1 |
        # | 1   1   1     1       1 |
        # | 1   1   1     1       1 |  row H // 2
        # | 0   0   0     0       0 |  row H // 2 + 1
        # | 0   0   0     0       0 |
        # -------------------------------------
        #   0   1  ...  W // 2  W // 2 + 1   (column index)
        mask = torch.ones(ch_out, ch_in, height, width)
        mask[:, :, height // 2 + 1:] = 0
        self.register_buffer('mask', mask)
        self.use_down_shift = use_down_shift

    def forward(self, x):
        # Zero out the masked weights before every forward pass.
        self.weight.data *= self.mask
        if self.use_down_shift:
            x = down_shift(x)
        return super(MaskedConv2d, self).forward(x)
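
# Illustrative note (not part of the original file): for a 3x3 kernel the
# mask above is
#   1 1 1
#   1 1 1
#   0 0 0
# The full centre row is kept, including pixels right of the centre; strict
# row-wise causality comes from pairing this mask with down_shift, as the
# experiment under __main__ demonstrates.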


def maskConv0(c_in=3, c_out=256, k_size=7, stride=1, pad=3):
    """2D Masked Convolution first layer"""
    return nn.Sequential(
        MaskedConv2d(c_in, c_out * 2, k_size, stride, pad, use_down_shift=True),
        nn.BatchNorm2d(c_out * 2),
        Gate()
    )


class Gate(nn.Module):
    def __init__(self):
        super(Gate, self).__init__()

    def forward(self, x):
        # gated activation
        xf, xg = torch.chunk(x, 2, dim=1)
        f = torch.tanh(xf)
        g = torch.sigmoid(xg)
        return f * g
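
# Illustrative note (not part of the original file): for an input with 2*h
# channels, Gate returns tanh(x[:, :h]) * sigmoid(x[:, h:]), halving the
# channel count -- the gated activation used in gated PixelCNN.
#   >>> Gate()(torch.randn(1, 8, 5, 5)).shape
#   torch.Size([1, 4, 5, 5])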


class MaskConvBlock(nn.Module):
    def __init__(self, h=128, k_size=3, stride=1, pad=1):
        """Gated 2D masked convolution with a residual connection"""
        super(MaskConvBlock, self).__init__()

        self.net = nn.Sequential(
            MaskedConv2d(h, 2 * h, k_size, stride, pad),
            nn.BatchNorm2d(2 * h),
            Gate()
        )

    def forward(self, x):
        """Residual connection"""
        return self.net(x) + x
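
# Sketch (assumption: not part of the original file) of how the layers above
# could compose into a PixelCNN-style stack: a down-shifted masked stem,
# gated residual blocks, and a 1x1 head producing per-pixel logits. The name
# `build_pixelcnn` and its parameters are illustrative, not from this repo.
def build_pixelcnn(c_in=3, h=128, n_blocks=5, n_out=256):
    layers = [maskConv0(c_in=c_in, c_out=h)]
    layers += [MaskConvBlock(h=h) for _ in range(n_blocks)]
    layers += [nn.Conv2d(h, n_out, kernel_size=1)]
    return nn.Sequential(*layers)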


if __name__ == '__main__':
    def conv(x, kernel):
        return nn.functional.conv2d(x, kernel, padding=1)

    # Blind-spot experiment: place a large marker value at (row 1, col 0)
    # and track how far it spreads under each convolution scheme.
    x = torch.ones((1, 1, 5, 5)) * 0.1
    x[:, :, 1, 0] = 1000

    print("blindspot experiment")
    normal_kernel = torch.ones(1, 1, 3, 3)
    mask_kernel = torch.zeros(1, 1, 3, 3)
    mask_kernel[:, :, 0, :] = 1      # top row only
    mask_b = mask_kernel.clone()
    mask_b[:, :, 1, 1] = 1           # type B: also the centre pixel

    print("unmasked kernel:", "\n", normal_kernel.squeeze(), "\n")
    print("masked kernel:", "\n", mask_kernel.squeeze(), "\n")

    print("normal conv")
    print("orig image", "\n", x.squeeze(), "\n")
    y = conv(x, normal_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, normal_kernel)
    print(y[:, 0, :, :], "\n")

    print("with mask")
    print("orig image", "\n", x.squeeze(), "\n")
    y = conv(x, mask_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, mask_b)
    print(y[:, 0, :, :], "\n")
    y = conv(y, mask_b)
    print(y[:, 0, :, :], "\n")

    print("with down_shift")
    print("orig image", "\n", x.squeeze(), "\n")
    c_kernel = mask_kernel.clone()   # clone so mask_kernel is not mutated
    c_kernel[:, :, 1, :] = 1         # keep the full centre row
    print("custom kernel:", "\n", c_kernel.squeeze(), "\n")
    y = conv(down_shift(x), c_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, c_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, c_kernel)
    print(y[:, 0, :, :], "\n")
    y = conv(y, c_kernel)
    print(y[:, 0, :, :], "\n")
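
    # Illustrative shape check (not part of the original file): run the
    # layers above on a random batch; spatial dims should be preserved.
    print("layer shape check")
    x_img = torch.randn(2, 3, 32, 32)
    stem = maskConv0(c_in=3, c_out=64)   # 7x7 down-shifted masked stem
    h_feat = stem(x_img)
    print(h_feat.shape)                  # torch.Size([2, 64, 32, 32])
    block = MaskConvBlock(h=64)          # gated residual masked conv
    print(block(h_feat).shape)           # torch.Size([2, 64, 32, 32])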