Diff of /CellGraph/model.py [000000] .. [2095ed]

--- a
+++ b/CellGraph/model.py
@@ -0,0 +1,193 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from resnet_custom import resnet50, resnet50_ln
+import math
+from pixelcnn import MaskCNN
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def initialize_weights(module):
+	"""
+	Kaiming-initialize all Conv2d and Linear layers; reset BatchNorm2d affine parameters.
+
+	args:
+		module: any pytorch module with trainable parameters
+	"""
+
+	for m in module.modules():
+		if isinstance(m, nn.Conv2d):
+			nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
+			if m.bias is not None:
+				m.bias.data.zero_()
+
+		elif isinstance(m, nn.Linear):
+			nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+			if m.bias is not None:
+				m.bias.data.zero_()
+
+		elif isinstance(m, nn.BatchNorm2d):
+			nn.init.constant_(m.weight, 1)
+			nn.init.constant_(m.bias, 0)
+
+
+class CPC_model(nn.Module):
+	def __init__(self, input_size=1024, hidden_size=128, k=3, ln=False):
+		"""
+		args:
+			input_size: input size to the autoregressor (encoding size)
+			hidden_size: number of hidden units in MaskCNN
+			k: prediction length (number of future rows to predict)
+			ln: if True, use the layer-norm variant of the ResNet-50 encoder
+		"""
+		super(CPC_model, self).__init__()
+		
+		### Settings
+		self.seq_len = 49 # 7 x 7 grid of overlapping 64 x 64 patches extracted from each 256 x 256 image
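+		# grid side length follows from the crop settings: (256 - 64) // 32 + 1 = 7 patches per dimension (see cropdata below)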
+		self.k = k
+		self.input_size = input_size
+		self.hidden_size = hidden_size
+
+
+		### Networks
+		if ln:
+			self.encoder = resnet50_ln(pretrained=False)
+		else:
+			self.encoder = resnet50(pretrained=False)
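+		# context network: MaskCNN applies masked (PixelCNN-style) convolutions over the feature grid, so each position's context stays causal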
+		self.reg = MaskCNN(n_channel=self.input_size, h=self.hidden_size)
+		network_pred = [nn.Linear(input_size, input_size) for i in range(self.k)] # use an independent linear layer to predict each future row
+		self.network_pred = nn.ModuleList(network_pred)
+		
+		# initialize linear network and context network
+		initialize_weights(self.network_pred)
+		initialize_weights(self.reg)
+
+
+		### Activation functions
+		self.softmax  = nn.Softmax(dim=1)
+		self.lsoftmax = nn.LogSoftmax(dim=1)
+
+	def forward(self, x):
+		# input = [bs * 7 * 7, 3, 64, 64]
+		
+		# compute batch_size
+		bs = x.size(0) // self.seq_len
+
+		rows = int(math.sqrt(self.seq_len))
+		cols = int(math.sqrt(self.seq_len))
+
+		# compute latent representation for each patch
+		z = self.encoder(x)
+		# z.shape: [bs * 7 * 7, 1024]
+		
+		# reshape z into feature grid: [bs, 7, 7, 1024]
+		z = z.contiguous().view(bs, rows, cols, self.input_size)
+		
+		device = z.device
+
+		# randomly draw a row; the model predicts the k rows below it using information from the rows above
+		if self.training:
+			pred_id = torch.randint(rows - self.k, size=(1,)).long() # pred_id in {0, ..., rows - k - 1}; e.g. pred_id = 3 predicts rows 4, 5, 6
+		else:
+			pred_id = torch.tensor([rows - self.k - 1]).long() # deterministic at eval time: the last valid row (3 for a 7-row grid with k = 3)
+		
+		# feature predictions for the next k rows, e.g. pred[i] is [bs * cols, 1024] for each i < k
+		pred = [torch.empty(bs * cols, self.input_size).float().to(device) for i in range(self.k)]
+
+		# ground-truth encodings for the next k rows, e.g. encode_samples[i] is [bs * cols, 1024] for each i < k
+		encode_samples = [torch.empty(bs * cols, self.input_size).float().to(device) for i in range(self.k)]
+
+		for i in range(self.k):
+			# collect the ground-truth encodings for row pred_id + i + 1
+			start_row = pred_id.item() + i + 1
+			encode_samples[i] = z[:, start_row, :, :].contiguous().view(bs * cols, self.input_size)
+
+		# reshape feature grid to channel first (required by Pytorch convolution convention)
+		z = z.permute(0, 3, 1, 2) 
+		# z.shape: from [bs, 7, 7, 1024] --> [bs, 1024, 7, 7]
+		
+		# apply aggregation to compute context 
+		output = self.reg(z)
+		# reg is fully convolutional --> output size is [bs, 1024, 7, 7]
+
+		output = output.permute(0, 2, 3, 1) # reshape back to feature grid
+		# output.shape: [bs, row, col, 1024]
+		
+		# context for each patch in the row
+		c_t = output[:,pred_id + 1,:, :]
+		# c_t.shape: [bs, 1, 7, 1024]
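+		# (indexing with the 1-element tensor pred_id + 1 keeps that axis, hence the singleton dimension)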
+
+		# reshape for linear classification:
+		c_t = c_t.contiguous().view(bs * cols, self.input_size)
+		# c_t.shape: [bs * cols, 1024]
+
+		# linear prediction: W_k * c_t
+		for i in range(self.k):
+			if isinstance(self.network_pred, nn.DataParallel):
+				pred[i] = self.network_pred.module[i](c_t)
+			else:
+				pred[i] = self.network_pred[i](c_t) # e.g. size [bs * cols, 1024]
+
+		nce = 0 # average over prediction length, cols, and batch 
+		accuracy = np.zeros((self.k,))
+		
+		for i in range(self.k):
+			"""
+			goal: can network correctly match predicted features with ground truth features among negative targets 
+			i.e. match z_i+k,j with W_k * c_i,j
+			postivie target: patch with the correct groundtruth encoding
+			negative targets: patches with wrong groundtruth encodings (sampled from other patches in the same image, or other images in the minibatch)
+
+			1) dot product for each k to obtain raw prediction logits 
+			total = (a_ij) = [bs * col, bs * col], where a_ij is the logit of ith patch prediction matching jth patch encoding
+			
+			2) apply softmax along each row to get probability that ith patch prediction matches jth patch encoding 
+			we want ith patch prediction to correctly match ith patch encoding, therefore target has 1s along diagnol, and 0s off diagnol
+
+			3) we take the argmax along softmaxed rows to get the patch prediction for the ith patch, this value should be i
+
+			4) compute nce loss as the cross-entropy of classifying the positive sample correctly (sum of logsoftmax along diagnol)
+
+			5) normalize loss by batchsize and k and number of patches in a row
+			
+			"""
+			total = torch.mm(pred[i], torch.transpose(encode_samples[i],0,1)) # e.g. size [bs * col, bs * col]
+
+			accuracy[i] = torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=1), torch.arange(0, bs * cols).to(device))).item() 
+			accuracy[i] /= 1. * (bs * cols) 
+			
+			nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor
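+			# equivalently: nce -= F.cross_entropy(total, torch.arange(bs * cols).to(device), reduction='sum')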
+		
+		nce /= -1. * bs * cols * self.k
+
+		return nce, accuracy
+
+
+# crop images into overlapping 64 x 64 patches (stride 32 gives a 32-pixel overlap)
+def cropdata(data, num_channels=3, kernel_size=64, stride=32):
+	if len(data.shape) == 3:
+		data = data.unsqueeze(0)
+
+	data = data.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
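+	# data.shape: [bs, C, 7, 7, 64, 64] for a 256 x 256 input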
+	data = data.permute(0,2,3,1,4,5)
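+	# data.shape: [bs, 7, 7, C, 64, 64] (patch grid first, channels after)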
+	data = data.contiguous().view(-1, num_channels, kernel_size, kernel_size)
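+	# data.shape: [bs * 49, C, 64, 64], the layout expected by CPC_model.forward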
+	return data
+
+if __name__ == '__main__':
+	torch.set_printoptions(threshold=1e6)
+	x = torch.rand(2, 3, 256, 256)
+	x = cropdata(x)
+	print(x.shape)
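+	# expected: torch.Size([98, 3, 64, 64]) -- 2 images x 49 patches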
+	model = CPC_model(1024, 256)
+	nce, accuracy = model(x)
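+	# smoke test of the forward pass: print the scalar InfoNCE loss and per-step accuracies
+	print(nce.item(), accuracy)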
+
+