CellGraph / model.py

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from resnet_custom import *
from pixelcnn import MaskCNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def initialize_weights(module):
    """
    args:
        module: any pytorch module with trainable parameters
    """
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

class CPC_model(nn.Module):
    def __init__(self, input_size=1024, hidden_size=128, k=3, ln=False):
        """
        args:
            input_size: input size to the autoregressor (encoding size)
            hidden_size: number of hidden units in the MaskCNN autoregressor
            k: prediction length (number of future rows to predict)
            ln: if True, use the layer-normalized ResNet-50 encoder
        """
        super(CPC_model, self).__init__()
        ### Settings
        self.seq_len = 49  # 7 x 7 grid of overlapping 64 x 64 patches extracted from each 256 x 256 image
        self.k = k
        self.input_size = input_size
        self.hidden_size = hidden_size
        ### Networks
        if ln:
            self.encoder = resnet50_ln(pretrained=False)
        else:
            self.encoder = resnet50(pretrained=False)
        self.reg = MaskCNN(n_channel=self.input_size, h=self.hidden_size)
        # use an independent linear layer W_k to predict each future row
        network_pred = [nn.Linear(input_size, input_size) for i in range(self.k)]
        self.network_pred = nn.ModuleList(network_pred)
        # initialize linear prediction heads and context network
        initialize_weights(self.network_pred)
        initialize_weights(self.reg)
        ### Activation functions
        self.softmax = nn.Softmax(dim=1)
        self.lsoftmax = nn.LogSoftmax(dim=1)
    def forward(self, x):
        # input x: [bs * 7 * 7, 3, 64, 64]
        # recover the batch size from the flattened patch dimension
        bs = x.size(0) // self.seq_len
        rows = int(math.sqrt(self.seq_len))
        cols = int(math.sqrt(self.seq_len))
        # compute latent representation for each patch
        z = self.encoder(x)
        # z.shape: [bs * 7 * 7, 1024]
        # reshape z into a feature grid: [bs, 7, 7, 1024]
        z = z.contiguous().view(bs, rows, cols, self.input_size)
        device = z.device
        # randomly draw a row index; the k rows below it are predicted from the
        # context at that row (the MaskCNN sees only the current row and above)
        if self.training:
            # pred_id is drawn uniformly from {0, ..., rows - k - 1}
            pred_id = torch.randint(rows - self.k, size=(1,)).long()
        else:
            # deterministic at eval time: fix pred_id to its largest valid value
            pred_id = torch.tensor([rows - self.k - 1]).long()
        # feature predictions for the next k rows, e.g. pred[i] is [bs * cols, 1024]
        pred = [torch.empty(bs * cols, self.input_size).float().to(device) for i in range(self.k)]
        # ground truth encodings for the next k rows, e.g. encode_samples[i] is [bs * cols, 1024]
        encode_samples = [torch.empty(bs * cols, self.input_size).float().to(device) for i in range(self.k)]
        for i in np.arange(self.k):
            # add ground truth encodings
            start_row = pred_id.item() + i + 1
            encode_samples[i] = z[:, start_row, :, :].contiguous().view(bs * cols, self.input_size)
        # reshape feature grid to channel first (required by PyTorch convolution convention)
        z = z.permute(0, 3, 1, 2)
        # z.shape: from [bs, 7, 7, 1024] --> [bs, 1024, 7, 7]
        # apply aggregation to compute context
        output = self.reg(z)
        # reg is fully convolutional --> output size is [bs, 1024, 7, 7]
        output = output.permute(0, 2, 3, 1)  # reshape back to feature grid
        # output.shape: [bs, row, col, 1024]
        # context for each patch in the row
        c_t = output[:, pred_id + 1, :, :]
        # c_t.shape: [bs, 1, 7, 1024]
        # reshape for linear prediction:
        c_t = c_t.contiguous().view(bs * cols, self.input_size)
        # c_t.shape: [bs * cols, 1024]
        # linear prediction: W_k * c_t
        for i in np.arange(0, self.k):
            if isinstance(self.network_pred, nn.DataParallel):
                pred[i] = self.network_pred.module[i](c_t)
            else:
                pred[i] = self.network_pred[i](c_t)  # e.g. size [bs * cols, 1024]
        nce = 0  # averaged over prediction length, cols, and batch
        accuracy = np.zeros((self.k,))
        for i in np.arange(0, self.k):
            """
            goal: can the network correctly match predicted features with ground truth
            features among negative targets, i.e. match z_{i+k,j} with W_k * c_{i,j}?
            positive target: the patch with the correct ground truth encoding
            negative targets: patches with wrong ground truth encodings (sampled from
                other patches in the same image, or from other images in the minibatch)
            1) dot product for each k to obtain raw prediction logits:
               total = (a_ij) = [bs * cols, bs * cols], where a_ij is the logit of the
               ith patch prediction matching the jth patch encoding
            2) apply softmax along each row to get the probability that the ith patch
               prediction matches the jth patch encoding; the ith prediction should
               match the ith encoding, so the target has 1s on the diagonal and 0s off it
            3) take the argmax along each softmaxed row to get the predicted match for
               the ith patch; this value should be i
            4) compute the nce loss as the cross-entropy of classifying the positive
               sample correctly (sum of the log-softmax along the diagonal)
            5) normalize the loss by batch size, k, and the number of patches in a row
            (see the sanity-check sketch after this class)
            """
            total = torch.mm(pred[i], torch.transpose(encode_samples[i], 0, 1))  # e.g. size [bs * cols, bs * cols]
            accuracy[i] = torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=1), torch.arange(0, bs * cols).to(device))).item()
            accuracy[i] /= 1. * (bs * cols)
            nce += torch.sum(torch.diag(self.lsoftmax(total)))  # nce is a tensor
        nce /= -1. * bs * cols * self.k
        return nce, accuracy
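
# A minimal sanity-check sketch (an illustrative addition, not part of the
# original model) mimicking steps 1-4 of the docstring above with a perfect
# predictor (pred == encode): the logit matrix is then maximal on the diagonal,
# so accuracy is exactly 1 and the loss falls below the chance level of
# log(bs * cols). The names `bs`, `cols`, and `dim` are assumptions chosen
# purely for illustration.
def _infonce_sanity_check(bs=2, cols=7, dim=16):
    pred = F.normalize(torch.randn(bs * cols, dim), dim=1)
    encode = pred.clone()  # pretend every prediction matches its ground truth encoding
    total = torch.mm(pred, encode.t())  # [bs * cols, bs * cols] raw logits
    correct = torch.argmax(total, dim=1) == torch.arange(bs * cols)
    acc = correct.float().mean().item()  # expect 1.0
    nce = -torch.diag(F.log_softmax(total, dim=1)).mean().item()
    return nce, acc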

# crop data into 64 x 64 patches with 32-pixel overlap
def cropdata(data, num_channels=3, kernel_size=64, stride=32):
    if len(data.shape) == 3:
        data = data.unsqueeze(0)  # add a batch dimension to a single image
    # slide a kernel_size window with the given stride over height, then width:
    # [bs, C, H, W] --> [bs, C, n_rows, n_cols, kernel_size, kernel_size]
    data = data.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # move the patch-grid dimensions ahead of the channel dimension
    data = data.permute(0, 2, 3, 1, 4, 5)
    # flatten to a batch of patches: [bs * n_rows * n_cols, C, kernel_size, kernel_size]
    data = data.contiguous().view(-1, num_channels, kernel_size, kernel_size)
    return data
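
# Illustrative shape check (an added sketch, not in the original file): with
# kernel_size 64 and stride 32 on a 256 x 256 image, each side yields
# (256 - 64) // 32 + 1 = 7 windows, i.e. 7 * 7 = 49 patches per image,
# which matches CPC_model.seq_len.
def _cropdata_shape_check():
    patches = cropdata(torch.rand(1, 3, 256, 256))
    assert patches.shape == (1 * 49, 3, 64, 64)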

if __name__ == '__main__':
    torch.set_printoptions(threshold=1e6)
    x = torch.rand(2, 3, 256, 256)
    x = cropdata(x)
    print(x.shape)
    model = CPC_model(1024, 256)
    nce, accuracy = model(x)
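    # Added usage note (an assumption about intent, not in the original file):
    # with an untrained encoder, per-step accuracy should sit near the chance
    # level of 1 / (bs * cols) = 1 / 14 for this smoke test.
    print('nce:', nce.item())
    print('accuracy per prediction step:', accuracy)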