Diff of /src/scMDC_batch.py [000000] .. [ac720d]


b/src/scMDC_batch.py
from sklearn.metrics.pairwise import paired_distances
from sklearn.decomposition import PCA
from sklearn import metrics
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from layers import NBLoss, ZINBLoss, MeanAct, DispAct
import numpy as np

import math, os

from utils import torch_PCA

from preprocess import read_dataset, normalize
import scanpy as sc

def buildNetwork1(layers, type, activation="relu"):
    # Stack fully-connected layers; for an encoder, the final linear layer is left without an activation.
    net = []
    for i in range(1, len(layers)):
        net.append(nn.Linear(layers[i-1], layers[i]))
        if type=="encode" and i==len(layers)-1:
            break
        if activation=="relu":
            net.append(nn.ReLU())
        elif activation=="sigmoid":
            net.append(nn.Sigmoid())
        elif activation=="elu":
            net.append(nn.ELU())
    return nn.Sequential(*net)

def buildNetwork2(layers, type, activation="relu"):
    # Like buildNetwork1, but every linear layer is followed by batch normalization
    # and the chosen activation (the type argument is accepted but not used here).
    net = []
    for i in range(1, len(layers)):
        net.append(nn.Linear(layers[i-1], layers[i]))
        net.append(nn.BatchNorm1d(layers[i], affine=True))
        if activation=="relu":
            net.append(nn.ReLU())
        elif activation=="selu":
            net.append(nn.SELU())
        elif activation=="sigmoid":
            net.append(nn.Sigmoid())
        elif activation=="elu":
            net.append(nn.ELU())
    return nn.Sequential(*net)

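# Illustrative example (hypothetical layer sizes): a call such as
#   buildNetwork2([2000, 256, 64], type="encode", activation="elu")
# returns nn.Sequential(Linear(2000, 256), BatchNorm1d(256), ELU(),
#                       Linear(256, 64), BatchNorm1d(64), ELU()),
# i.e. each linear layer is followed by batch normalization and the activation.
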
class scMultiClusterBatch(nn.Module):
    def __init__(self, input_dim1, input_dim2, n_batch,
            encodeLayer=[], decodeLayer1=[], decodeLayer2=[], tau=1., t=10, device="cuda",
            activation="elu", sigma1=2.5, sigma2=.1, alpha=1., gamma=1., phi1=0.0001, phi2=0.0001, cutoff=0.5):
        super(scMultiClusterBatch, self).__init__()
        self.tau = tau
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.cutoff = cutoff
        self.activation = activation
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.alpha = alpha
        self.gamma = gamma
        self.phi1 = phi1
        self.phi2 = phi2
        self.t = t
        self.device = device
        # The encoder consumes both modalities plus a one-hot batch indicator;
        # the decoders receive the batch indicator appended to the latent code.
        self.encoder = buildNetwork2([input_dim1+input_dim2+n_batch]+encodeLayer, type="encode", activation=activation)
        self.decoder1 = buildNetwork2([decodeLayer1[0]+n_batch]+decodeLayer1[1:], type="decode", activation=activation)
        self.decoder2 = buildNetwork2([decodeLayer2[0]+n_batch]+decodeLayer2[1:], type="decode", activation=activation)
        self.dec_mean1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), MeanAct())
        self.dec_disp1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), DispAct())
        self.dec_mean2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), MeanAct())
        self.dec_disp2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), DispAct())
        self.dec_pi1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), nn.Sigmoid())
        self.dec_pi2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), nn.Sigmoid())
        self.zinb_loss = ZINBLoss()
        self.NBLoss = NBLoss()
        self.mse = nn.MSELoss()
        self.z_dim = encodeLayer[-1]

    def save_model(self, path):
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def soft_assign(self, z):
        # Student's t-kernel similarity between embedded points and cluster centers (as in DEC)
        q = 1.0 / (1.0 + torch.sum((z.unsqueeze(1) - self.mu)**2, dim=2) / self.alpha)
        q = q**((self.alpha+1.0)/2.0)
        q = (q.t() / torch.sum(q, dim=1)).t()
        return q

    def cal_latent(self, z):
        # Pairwise t-distributed similarities between cells in the latent space;
        # returns the raw kernel and the row-normalized probabilities (diagonal zeroed)
        sum_y = torch.sum(torch.square(z), dim=1)
        num = -2.0 * torch.matmul(z, z.t()) + torch.reshape(sum_y, [-1, 1]) + sum_y
        num = num / self.alpha
        num = torch.pow(1.0 + num, -(self.alpha + 1.0) / 2.0)
        zerodiag_num = num - torch.diag(torch.diag(num))
        latent_p = (zerodiag_num.t() / torch.sum(zerodiag_num, dim=1)).t()
        return num, latent_p

    def target_distribution(self, q):
        # Sharpened target distribution: square q, divide by column sums, renormalize per row
        p = q**2 / q.sum(0)
        return (p.t() / p.sum(1)).t()

    def kmeans_loss(self, z):
        # Soft k-means style loss: distances to cluster centers weighted by squared, normalized soft assignments
        dist1 = self.tau * torch.sum(torch.square(z.unsqueeze(1) - self.mu), dim=2)
        temp_dist1 = dist1 - torch.reshape(torch.mean(dist1, dim=1), [-1, 1])
        q = torch.exp(-temp_dist1)
        q = (q.t() / torch.sum(q, dim=1)).t()
        q = torch.pow(q, 2)
        q = (q.t() / torch.sum(q, dim=1)).t()
        dist2 = dist1 * q
        return dist1, torch.mean(torch.sum(dist2, dim=1))

    def forward(self, x1, x2, b):
        # Denoising pass: Gaussian noise is added to each modality before encoding
        x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
        h = self.encoder(torch.cat([x, b], dim=-1))
        h = torch.cat([h, b], dim=-1)

        h1 = self.decoder1(h)
        mean1 = self.dec_mean1(h1)
        disp1 = self.dec_disp1(h1)
        pi1 = self.dec_pi1(h1)

        h2 = self.decoder2(h)
        mean2 = self.dec_mean2(h2)
        disp2 = self.dec_disp2(h2)
        pi2 = self.dec_pi2(h2)

        # Clean pass (no noise) to obtain the latent embedding used for clustering
        x0 = torch.cat([x1, x2], dim=-1)
        h0 = self.encoder(torch.cat([x0, b], dim=-1))
        q = self.soft_assign(h0)
        num, lq = self.cal_latent(h0)
        return h0, q, num, lq, mean1, mean2, disp1, disp2, pi1, pi2

    def forwardAE(self, x1, x2, b):
        # Same as forward() but without the soft cluster assignment (used during pretraining)
        x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
        h = self.encoder(torch.cat([x, b], dim=-1))
        h = torch.cat([h, b], dim=-1)

        h1 = self.decoder1(h)
        mean1 = self.dec_mean1(h1)
        disp1 = self.dec_disp1(h1)
        pi1 = self.dec_pi1(h1)

        h2 = self.decoder2(h)
        mean2 = self.dec_mean2(h2)
        disp2 = self.dec_disp2(h2)
        pi2 = self.dec_pi2(h2)

        x0 = torch.cat([x1, x2], dim=-1)
        h0 = self.encoder(torch.cat([x0, b], dim=-1))
        num, lq = self.cal_latent(h0)
        return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2

    def encodeBatch(self, X1, X2, B, batch_size=256):
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.to(self.device)
        encoded = []
        self.eval()
        num = X1.shape[0]
        num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
        for batch_idx in range(num_batch):
            x1batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
            x2batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
            b_batch = B[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
            inputs1 = Variable(x1batch).to(self.device)
            inputs2 = Variable(x2batch).to(self.device)
            b_tensor = Variable(b_batch).to(self.device)
            z, _, _, _, _, _, _, _, _ = self.forwardAE(inputs1.float(), inputs2.float(), b_tensor.float())
            encoded.append(z.data)

        encoded = torch.cat(encoded, dim=0)
        return encoded

    def cluster_loss(self, p, q):
        def kld(target, pred):
            return torch.mean(torch.sum(target*torch.log(target/(pred+1e-6)), dim=-1))
        kldloss = kld(p, q)
        return kldloss

    def kldloss(self, p, q):
        c1 = -torch.sum(p * torch.log(q), dim=-1)
        c2 = -torch.sum(p * torch.log(p), dim=-1)
        return torch.mean(c1 - c2)

    def SDis_func(self, x, y):
        return torch.sum(torch.square(x - y), dim=1)

    def pretrain_autoencoder(self, X1, X_raw1, sf1, X2, X_raw2, sf2, B,
            batch_size=256, lr=0.001, epochs=400, ae_save=True, ae_weights='AE_weights.pth.tar'):
        num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
        dataset = TensorDataset(torch.Tensor(X1), torch.Tensor(X_raw1), torch.Tensor(sf1), torch.Tensor(X2), torch.Tensor(X_raw2), torch.Tensor(sf2), torch.Tensor(B))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print("Pretraining stage")
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)
        counts = 0
        for epoch in range(epochs):
            loss_val = 0
            recon_loss1_val = 0
            recon_loss2_val = 0
            kl_loss_val = 0
            for batch_idx, (x1_batch, x_raw1_batch, sf1_batch, x2_batch, x_raw2_batch, sf2_batch, b_batch) in enumerate(dataloader):
                x1_tensor = Variable(x1_batch).to(self.device)
                x_raw1_tensor = Variable(x_raw1_batch).to(self.device)
                sf1_tensor = Variable(sf1_batch).to(self.device)
                x2_tensor = Variable(x2_batch).to(self.device)
                x_raw2_tensor = Variable(x_raw2_batch).to(self.device)
                sf2_tensor = Variable(sf2_batch).to(self.device)
                b_tensor = Variable(b_batch).to(self.device)
                zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forwardAE(x1_tensor, x2_tensor, b_tensor)
                #recon_loss1 = self.mse(mean1_tensor, x1_tensor)
                recon_loss1 = self.zinb_loss(x=x_raw1_tensor, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sf1_tensor)
                #recon_loss2 = self.mse(mean2_tensor, x2_tensor)
                recon_loss2 = self.zinb_loss(x=x_raw2_tensor, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sf2_tensor)
                lpbatch = self.target_distribution(lqbatch)
                lqbatch = lqbatch + torch.diag(torch.diag(z_num))
                lpbatch = lpbatch + torch.diag(torch.diag(z_num))
                kl_loss = self.kldloss(lpbatch, lqbatch)
                # The KL term is only switched on once epoch+1 reaches epochs * self.cutoff
                if epoch+1 >= epochs * self.cutoff:
                    loss = recon_loss1 + recon_loss2 + kl_loss * self.phi1
                else:
                    loss = recon_loss1 + recon_loss2 #+ kl_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_val += loss.item() * len(x1_batch)
                recon_loss1_val += recon_loss1.item() * len(x1_batch)
                recon_loss2_val += recon_loss2.item() * len(x1_batch)
                if epoch+1 >= epochs * self.cutoff:
                    kl_loss_val += kl_loss.item() * len(x1_batch)

            loss_val = loss_val/X1.shape[0]
            recon_loss1_val = recon_loss1_val/X1.shape[0]
            recon_loss2_val = recon_loss2_val/X1.shape[0]
            kl_loss_val = kl_loss_val/X1.shape[0]
            if epoch%self.t == 0:
                print('Pretrain epoch {}, Total loss:{:.6f}, ZINB loss1:{:.6f}, ZINB loss2:{:.6f}, KL loss:{:.6f}'.format(epoch+1, loss_val, recon_loss1_val, recon_loss2_val, kl_loss_val))

        if ae_save:
            torch.save({'ae_state_dict': self.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, ae_weights)

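    # Illustrative sketch: a pretraining checkpoint written above (default path
    # 'AE_weights.pth.tar') can be restored with standard PyTorch calls, e.g.
    #   checkpoint = torch.load('AE_weights.pth.tar')
    #   model.load_state_dict(checkpoint['ae_state_dict'])
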
    def save_checkpoint(self, state, index, filename):
        newfilename = os.path.join(filename, 'FTcheckpoint_%d.pth.tar' % index)
        torch.save(state, newfilename)

    def fit(self, X1, X_raw1, sf1, X2, X_raw2, sf2, B, y=None, lr=1., n_clusters=4,
            batch_size=256, num_epochs=10, update_interval=1, tol=1e-3, save_dir=""):
        '''X: tensor data'''
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.to(self.device)
        print("Clustering stage")
        X1 = torch.tensor(X1).to(self.device)
        X_raw1 = torch.tensor(X_raw1).to(self.device)
        sf1 = torch.tensor(sf1).to(self.device)
        X2 = torch.tensor(X2).to(self.device)
        X_raw2 = torch.tensor(X_raw2).to(self.device)
        sf2 = torch.tensor(sf2).to(self.device)
        B = torch.tensor(B).to(self.device)
        self.mu = Parameter(torch.Tensor(n_clusters, self.z_dim), requires_grad=True)
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95)
        #optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=0.001)

        print("Initializing cluster centers with kmeans.")
        kmeans = KMeans(n_clusters, n_init=20)
        Zdata = self.encodeBatch(X1, X2, B, batch_size=batch_size)
        #latent
        self.y_pred = kmeans.fit_predict(Zdata.data.cpu().numpy())
        self.y_pred_last = self.y_pred
        self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_))
        if y is not None:
            ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
            nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
            ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
            print('Initializing k-means: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))

        self.train()
        num = X1.shape[0]
        num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))

        final_nmi, final_ari, final_epoch = 0, 0, 0

        for epoch in range(num_epochs):
            if epoch%update_interval == 0:
                # update the target distribution p
                Zdata = self.encodeBatch(X1, X2, B, batch_size=batch_size)

                # evaluate the clustering performance
                dist, _ = self.kmeans_loss(Zdata)
                self.y_pred = torch.argmin(dist, dim=1).data.cpu().numpy()

                if y is not None:
                    #acc2 = np.round(cluster_acc(y, self.y_pred), 5)
                    final_ami = ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
                    final_nmi = nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
                    final_ari = ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
                    final_epoch = epoch+1
                    print('Clustering %d: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (epoch+1, ami, nmi, ari))

                # check stop criterion
                delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num
                self.y_pred_last = self.y_pred
                if epoch>0 and delta_label < tol:
                    print('delta_label ', delta_label, '< tol ', tol)
                    print("Reached tolerance threshold. Stopping training.")
                    break

                # save current model
                # if (epoch>0 and delta_label < tol) or epoch%10 == 0:
                    # self.save_checkpoint({'epoch': epoch+1,
                            # 'state_dict': self.state_dict(),
                            # 'mu': self.mu,
                            # 'y_pred': self.y_pred,
                            # 'y_pred_last': self.y_pred_last,
                            # 'y': y
                            # }, epoch+1, filename=save_dir)

            # train one epoch for the clustering loss
            train_loss = 0.0
            recon_loss1_val = 0.0
            recon_loss2_val = 0.0
            recon_loss_latent_val = 0.0
            cluster_loss_val = 0.0
            kl_loss_val = 0.0
            for batch_idx in range(num_batch):
                x1_batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                x_raw1_batch = X_raw1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                sf1_batch = sf1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                x2_batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                x_raw2_batch = X_raw2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                sf2_batch = sf2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                b_batch = B[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                optimizer.zero_grad()
                inputs1 = Variable(x1_batch)
                rawinputs1 = Variable(x_raw1_batch)
                sfinputs1 = Variable(sf1_batch)
                inputs2 = Variable(x2_batch)
                rawinputs2 = Variable(x_raw2_batch)
                sfinputs2 = Variable(sf2_batch)

                zbatch, qbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forward(inputs1.float(), inputs2.float(), b_batch.float())

                _, cluster_loss = self.kmeans_loss(zbatch)
                recon_loss1 = self.zinb_loss(x=rawinputs1, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sfinputs1)
                recon_loss2 = self.zinb_loss(x=rawinputs2, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sfinputs2)
                target2 = self.target_distribution(lqbatch)
                lqbatch = lqbatch + torch.diag(torch.diag(z_num))
                target2 = target2 + torch.diag(torch.diag(z_num))
                kl_loss = self.kldloss(target2, lqbatch)
                loss = cluster_loss * self.gamma + kl_loss * self.phi2 + recon_loss1 + recon_loss2
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.mu, 1)
                optimizer.step()
                cluster_loss_val += cluster_loss.data * len(inputs1)
                recon_loss1_val += recon_loss1.data * len(inputs1)
                recon_loss2_val += recon_loss2.data * len(inputs2)
                kl_loss_val += kl_loss.data * len(inputs1)
                loss_val = cluster_loss_val + recon_loss1_val + recon_loss2_val + kl_loss_val

            if epoch%self.t == 0:
                print("#Epoch %d: Total: %.6f Clustering Loss: %.6f ZINB Loss1: %.6f ZINB Loss2: %.6f KL Loss: %.6f" % (
                     epoch + 1, loss_val / num, cluster_loss_val / num, recon_loss1_val / num, recon_loss2_val / num, kl_loss_val / num))

        return self.y_pred, final_epoch
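
A minimal usage sketch, not taken from the repository: it assumes the repository's layers, utils, and preprocess modules are importable, that X1/X2 are preprocessed mRNA and ADT matrices with matching raw counts and size factors, and that B is a one-hot batch matrix. All dimensions, layer sizes, and training settings below are illustrative placeholders.

import numpy as np
import torch
from scMDC_batch import scMultiClusterBatch  # assuming src/ is on the Python path

# Hypothetical data: 1000 cells, 2000 genes (mRNA), 30 proteins (ADT), 2 batches
n_cells, d1, d2, n_batch = 1000, 2000, 30, 2
X1 = np.random.rand(n_cells, d1).astype(np.float32)                 # normalized mRNA matrix
X_raw1 = np.random.poisson(1., (n_cells, d1)).astype(np.float32)    # raw mRNA counts
sf1 = np.ones(n_cells, dtype=np.float32)                            # mRNA size factors
X2 = np.random.rand(n_cells, d2).astype(np.float32)                 # normalized ADT matrix
X_raw2 = np.random.poisson(1., (n_cells, d2)).astype(np.float32)    # raw ADT counts
sf2 = np.ones(n_cells, dtype=np.float32)                            # ADT size factors
B = np.eye(n_batch, dtype=np.float32)[np.random.randint(0, n_batch, n_cells)]  # one-hot batch labels

device = "cuda" if torch.cuda.is_available() else "cpu"
model = scMultiClusterBatch(input_dim1=d1, input_dim2=d2, n_batch=n_batch,
        encodeLayer=[256, 64, 32, 16], decodeLayer1=[16, 64, 256], decodeLayer2=[16, 20],
        activation="elu", sigma1=2.5, sigma2=1.5, gamma=1., cutoff=0.5, device=device).to(device)

# Pretrain the autoencoder, then run the clustering stage
model.pretrain_autoencoder(X1, X_raw1, sf1, X2, X_raw2, sf2, B,
        batch_size=256, epochs=50, ae_weights='AE_weights.pth.tar')
y_pred, final_epoch = model.fit(X1, X_raw1, sf1, X2, X_raw2, sf2, B,
        n_clusters=4, batch_size=256, num_epochs=20, save_dir="")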