Diff of /CellGraph/model.py [000000] .. [2095ed]

Switch to unified view

a b/CellGraph/model.py
1
import numpy as np
2
import torch
3
import torch.nn as nn
4
import torch.nn.functional as F
5
from torch.autograd import Variable
6
from resnet_custom import *
7
import pdb
8
import math
9
from pixelcnn import MaskCNN
10
11
# Module-level default device: CUDA when torch reports it available, otherwise CPU.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
13
14
15
def initialize_weights(module):
    """
    Re-initialize, in place, every Conv2d, Linear and BatchNorm2d found in a module.

    Conv2d and Linear weights are drawn with Kaiming-normal initialization
    (ReLU gain; Linear uses mode='fan_out'). BatchNorm2d is reset to
    weight 1 / bias 0. Biases are zeroed only when the layer actually has one.

    args:
        module: any pytorch module with trainable parameters
    """
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()

        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            # bias is None for nn.Linear(..., bias=False)
            if m.bias is not None:
                m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm2d):
            # affine=False leaves both weight and bias as None
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
38
39
40
class CPC_model(nn.Module):
    """
    Contrastive predictive coding over a square grid of patch encodings.

    Each patch is encoded, the encodings are arranged as a rows x cols grid,
    a context network (self.reg) aggregates the grid, and k independent linear
    heads predict the encodings of the next k rows from one context row.
    forward() returns the InfoNCE-style loss and the per-step matching accuracy.
    """

    def __init__(self, input_size = 1024, hidden_size = 128, k = 3, ln = False):
        """
        args:
            input_size: input size to autoregresser (encoding size)
            hidden_size: number of hidden units in MaskedCNN
            k: prediction length (number of future rows to predict)
            ln: if True the encoder is built with resnet50_ln, otherwise with resnet50
        """
        super(CPC_model, self).__init__()

        ### Settings
        self.seq_len = 49 # 7 x 7 grid of overlapping 64 x 64 patches extracted from each 256 x 256 image
        self.k = k
        self.input_size = input_size
        self.hidden_size = hidden_size

        ### Networks
        if ln:
            self.encoder = resnet50_ln(pretrained=False)
        else:
            self.encoder = resnet50(pretrained=False)
        self.reg = MaskCNN(n_channel=self.input_size, h=self.hidden_size)
        # use an independent linear layer to predict each future row
        self.network_pred = nn.ModuleList(
            [nn.Linear(input_size, input_size) for _ in range(self.k)]
        )

        # initialize linear network and context network
        initialize_weights(self.network_pred)
        initialize_weights(self.reg)

        ### Activation functions
        self.softmax = nn.Softmax(dim=1)
        self.lsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """
        Compute the contrastive loss and matching accuracy for a batch of patches.

        args:
            x: patches for whole images, seq_len consecutive patches per image
               (expected [bs * 7 * 7, 3, 64, 64] -- depends on the encoder)

        returns:
            nce: scalar tensor, negative log-softmax of the positive pairs averaged
                 over batch, columns and the k prediction steps
            accuracy: numpy array of shape (k,), fraction of predictions whose
                      best-scoring candidate is the true encoding, per step

        raises:
            ValueError: if k does not leave at least one context row in the grid
        """
        # compute batch_size
        bs = x.size(0) // self.seq_len

        rows = int(math.sqrt(self.seq_len))
        cols = rows

        if not 0 < self.k < rows:
            raise ValueError(
                "k must satisfy 0 < k < {} (grid rows), got {}".format(rows, self.k)
            )

        # compute latent representation for each patch
        z = self.encoder(x)
        # reshape z into feature grid: [bs, rows, cols, input_size]
        z = z.contiguous().view(bs, rows, cols, self.input_size)

        device = z.device

        # pick the row whose k following rows are predicted
        if self.training:
            # uniform over 0 .. rows - k - 1 so that the k target rows stay inside the grid
            pred_id = int(torch.randint(rows - self.k, size=(1,)).item())
        else:
            # fixed row at evaluation time: 3 for k <= 3, clamped for larger k
            # so that rows pred_id + 1 .. pred_id + k never leave the grid
            pred_id = min(3, rows - self.k - 1)

        # ground truth encodings for the next k rows, each [bs * cols, input_size]
        encode_samples = [
            z[:, pred_id + i + 1, :, :].contiguous().view(bs * cols, self.input_size)
            for i in range(self.k)
        ]

        # reshape feature grid to channel first (required by Pytorch convolution convention)
        # [bs, rows, cols, input_size] --> [bs, input_size, rows, cols]
        z = z.permute(0, 3, 1, 2)

        # apply aggregation to compute context; the reshape below relies on
        # reg returning the same [bs, input_size, rows, cols] shape
        output = self.reg(z)
        output = output.permute(0, 2, 3, 1) # back to [bs, rows, cols, input_size]

        # context for each patch in row pred_id + 1, flattened to [bs * cols, input_size]
        c_t = output[:, pred_id + 1, :, :]
        c_t = c_t.contiguous().view(bs * cols, self.input_size)

        # linear prediction: Wk*c_t, one head per future row
        heads = self.network_pred
        if isinstance(heads, nn.DataParallel):
            heads = heads.module
        pred = [heads[i](c_t) for i in range(self.k)]

        nce = 0 # average over prediction length, cols, and batch
        accuracy = np.zeros((self.k,))
        # the i-th prediction should match the i-th encoding (diagonal target)
        targets = torch.arange(0, bs * cols).to(device)

        for i in range(self.k):
            # total[a, b] is the logit of prediction a matching encoding b;
            # every other patch in the row / minibatch acts as a negative
            total = torch.mm(pred[i], torch.transpose(encode_samples[i], 0, 1))

            best = torch.argmax(self.softmax(total), dim=1)
            correct = torch.sum(torch.eq(best, targets)).item()
            accuracy[i] = correct / (1. * bs * cols)

            # cross-entropy of picking the positive sample: sum of log-softmax on the diagonal
            nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor

        nce /= -1. * bs * cols * self.k

        return nce, np.array(accuracy)
173
174
175
# crop data into 64 by 64 with 32 overlap 
176
def cropdata(data, num_channels=3, kernel_size = 64, stride = 32):
    """
    Cut a batch of images into square patches taken on a regular grid.

    args:
        data: image tensor [N, C, H, W]; a 3-D tensor is treated as one image [C, H, W]
        num_channels: channel count used to shape the result (should equal C)
        kernel_size: side length of each patch
        stride: step between neighbouring patches (stride < kernel_size gives overlap)

    returns:
        tensor [N * n_rows * n_cols, num_channels, kernel_size, kernel_size];
        patches are ordered image by image, then row by row, then column by column
    """
    if data.dim() == 3:
        data = data.unsqueeze(0)

    # sliding windows over height then width: [N, C, n_rows, n_cols, k, k]
    patches = data.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # move the patch grid in front of the channels: [N, n_rows, n_cols, C, k, k]
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    return patches.contiguous().view(-1, num_channels, kernel_size, kernel_size)
184
185
if __name__ == '__main__':
    # Smoke test: push two random 256 x 256 RGB images through the model once.
    torch.set_printoptions(threshold=1000000)
    x = torch.rand(2, 3, 256, 256)
    x = cropdata(x)
    print(x.shape)
    model = CPC_model(1024, 256)
    nce, accuracy = model(x)
    # show the results instead of silently discarding them
    print(nce.item(), accuracy)
192
193