Diff of /src/scMDC.py [000000] .. [ac720d]


--- a/src/scMDC.py
+++ b/src/scMDC.py
from sklearn import metrics
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from layers import ZINBLoss, MeanAct, DispAct
import numpy as np

import math, os

def buildNetwork2(layers, net_type, activation="relu"):
    """Build an MLP of Linear -> BatchNorm1d -> activation blocks.

    `layers` lists the widths [in, hidden1, ..., out]; `net_type`
    ("encode"/"decode") is kept for call-site symmetry but is unused."""
    net = []
    for i in range(1, len(layers)):
        net.append(nn.Linear(layers[i-1], layers[i]))
        net.append(nn.BatchNorm1d(layers[i], affine=True))
        if activation == "relu":
            net.append(nn.ReLU())
        elif activation == "selu":
            net.append(nn.SELU())
        elif activation == "sigmoid":
            net.append(nn.Sigmoid())
        elif activation == "elu":
            net.append(nn.ELU())
    return nn.Sequential(*net)
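
# Example (a sketch with hypothetical widths): buildNetwork2([100, 64, 32], net_type="encode")
# returns nn.Sequential(Linear(100, 64), BatchNorm1d(64), ReLU(),
#                       Linear(64, 32), BatchNorm1d(32), ReLU()).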

class scMultiCluster(nn.Module):
    def __init__(self, input_dim1, input_dim2,
            encodeLayer=[], decodeLayer1=[], decodeLayer2=[], tau=1., t=10, device="cuda",
            activation="elu", sigma1=2.5, sigma2=.1, alpha=1., gamma=1., phi1=0.0001, phi2=0.0001, cutoff=0.5):
        super(scMultiCluster, self).__init__()
        self.tau = tau
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.cutoff = cutoff          # fraction of pretraining epochs before the KL term is enabled
        self.activation = activation
        self.sigma1 = sigma1          # Gaussian noise level for modality 1
        self.sigma2 = sigma2          # Gaussian noise level for modality 2
        self.alpha = alpha            # degrees of freedom of the latent t-kernel
        self.gamma = gamma            # weight of the clustering (k-means) loss
        self.phi1 = phi1              # KL weight during pretraining
        self.phi2 = phi2              # KL weight during clustering
        self.t = t                    # logging interval in epochs
        self.device = device
        # One shared encoder over the concatenated modalities; one decoder per modality.
        self.encoder = buildNetwork2([input_dim1+input_dim2]+encodeLayer, net_type="encode", activation=activation)
        self.decoder1 = buildNetwork2(decodeLayer1, net_type="decode", activation=activation)
        self.decoder2 = buildNetwork2(decodeLayer2, net_type="decode", activation=activation)
        # ZINB output heads: per-feature mean, dispersion and dropout probability.
        self.dec_mean1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), MeanAct())
        self.dec_disp1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), DispAct())
        self.dec_mean2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), MeanAct())
        self.dec_disp2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), DispAct())
        self.dec_pi1 = nn.Sequential(nn.Linear(decodeLayer1[-1], input_dim1), nn.Sigmoid())
        self.dec_pi2 = nn.Sequential(nn.Linear(decodeLayer2[-1], input_dim2), nn.Sigmoid())
        self.zinb_loss = ZINBLoss()
        self.z_dim = encodeLayer[-1]
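
    # Construction sketch (hypothetical layer widths, not prescribed by this file):
    #     model = scMultiCluster(input_dim1=2000, input_dim2=30,
    #                            encodeLayer=[256, 64, 32],
    #                            decodeLayer1=[32, 64, 256], decodeLayer2=[32, 20])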

    def save_model(self, path):
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
        model_dict = self.state_dict()
        # Keep only the keys that exist in this model, so partially matching
        # checkpoints (e.g. pretrained autoencoder weights) can be loaded.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def cal_latent(self, z):
        # Pairwise squared distances via ||z_i||^2 - 2*z_i.z_j + ||z_j||^2.
        sum_y = torch.sum(torch.square(z), dim=1)
        num = -2.0 * torch.matmul(z, z.t()) + torch.reshape(sum_y, [-1, 1]) + sum_y
        # Student's t kernel: (1 + d_ij^2 / alpha)^(-(alpha + 1) / 2).
        num = num / self.alpha
        num = torch.pow(1.0 + num, -(self.alpha + 1.0) / 2.0)
        # Zero the diagonal (no self-similarity), then row-normalize.
        zerodiag_num = num - torch.diag(torch.diag(num))
        latent_p = (zerodiag_num.t() / torch.sum(zerodiag_num, dim=1)).t()
        return num, latent_p

    def kmeans_loss(self, z):
        # Squared distances to the cluster centers mu, scaled by tau.
        dist1 = self.tau*torch.sum(torch.square(z.unsqueeze(1) - self.mu), dim=2)
        # Subtract the per-sample mean before exponentiating for numerical stability.
        temp_dist1 = dist1 - torch.reshape(torch.mean(dist1, dim=1), [-1, 1])
        q = torch.exp(-temp_dist1)
        q = (q.t() / torch.sum(q, dim=1)).t()
        # Sharpen the soft assignment (square, then renormalize) before weighting.
        q = torch.pow(q, 2)
        q = (q.t() / torch.sum(q, dim=1)).t()
        dist2 = dist1 * q
        return dist1, torch.mean(torch.sum(dist2, dim=1))

    def target_distribution(self, q):
        # DEC-style target: p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'), with f_j = sum_i q_ij.
        p = q**2 / q.sum(0)
        return (p.t() / p.sum(1)).t()
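
    # Worked example (sketch): for q = [[0.9, 0.1], [0.5, 0.5]] the target is roughly
    # [[0.97, 0.03], [0.30, 0.70]] -- confident rows are sharpened, and the ambiguous
    # row is tilted toward the less frequent cluster.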

    def forward(self, x1, x2):
        # Denoising pass: add Gaussian noise per modality, encode, and decode the
        # ZINB parameters (mean, dispersion, dropout) for each modality.
        x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
        h = self.encoder(x)

        h1 = self.decoder1(h)
        mean1 = self.dec_mean1(h1)
        disp1 = self.dec_disp1(h1)
        pi1 = self.dec_pi1(h1)

        h2 = self.decoder2(h)
        mean2 = self.dec_mean2(h2)
        disp2 = self.dec_disp2(h2)
        pi2 = self.dec_pi2(h2)

        # Clean (noise-free) pass supplies the latent code and its t-kernel similarities.
        x0 = torch.cat([x1, x2], dim=-1)
        h0 = self.encoder(x0)
        num, lq = self.cal_latent(h0)
        return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2

    def forwardAE(self, x1, x2):
        # Same computation as forward(); used by the pretraining and encoding paths.
        x = torch.cat([x1+torch.randn_like(x1)*self.sigma1, x2+torch.randn_like(x2)*self.sigma2], dim=-1)
        h = self.encoder(x)

        h1 = self.decoder1(h)
        mean1 = self.dec_mean1(h1)
        disp1 = self.dec_disp1(h1)
        pi1 = self.dec_pi1(h1)

        h2 = self.decoder2(h)
        mean2 = self.dec_mean2(h2)
        disp2 = self.dec_disp2(h2)
        pi2 = self.dec_pi2(h2)

        x0 = torch.cat([x1, x2], dim=-1)
        h0 = self.encoder(x0)
        num, lq = self.cal_latent(h0)
        return h0, num, lq, mean1, mean2, disp1, disp2, pi1, pi2

    def encodeBatch(self, X1, X2, batch_size=256):
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.to(self.device)
        encoded = []
        # Note: this switches the model to eval mode (BatchNorm uses running stats)
        # and does not switch it back; callers that keep training must call self.train().
        self.eval()
        num = X1.shape[0]
        num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))
        for batch_idx in range(num_batch):
            x1batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
            x2batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
            inputs1 = Variable(x1batch)
            inputs2 = Variable(x2batch)
            z,_,_,_,_,_,_,_,_ = self.forwardAE(inputs1, inputs2)
            encoded.append(z.data)

        encoded = torch.cat(encoded, dim=0)
        return encoded

    def kldloss(self, p, q):
        # KL(p || q) = sum_j p_j * (log p_j - log q_j), averaged over the batch.
        c1 = -torch.sum(p * torch.log(q), dim=-1)
        c2 = -torch.sum(p * torch.log(p), dim=-1)
        return torch.mean(c1 - c2)

    def pretrain_autoencoder(self, X1, X_raw1, sf1, X2, X_raw2, sf2,
            batch_size=256, lr=0.001, epochs=400, ae_save=True, ae_weights='AE_weights.pth.tar'):
        dataset = TensorDataset(torch.Tensor(X1), torch.Tensor(X_raw1), torch.Tensor(sf1), torch.Tensor(X2), torch.Tensor(X_raw2), torch.Tensor(sf2))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print("Pretraining stage")
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)
        num = X1.shape[0]
        for epoch in range(epochs):
            loss_val = 0
            recon_loss1_val = 0
            recon_loss2_val = 0
            kl_loss_val = 0
            for batch_idx, (x1_batch, x_raw1_batch, sf1_batch, x2_batch, x_raw2_batch, sf2_batch) in enumerate(dataloader):
                x1_tensor = Variable(x1_batch).to(self.device)
                x_raw1_tensor = Variable(x_raw1_batch).to(self.device)
                sf1_tensor = Variable(sf1_batch).to(self.device)
                x2_tensor = Variable(x2_batch).to(self.device)
                x_raw2_tensor = Variable(x_raw2_batch).to(self.device)
                sf2_tensor = Variable(sf2_batch).to(self.device)
                zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forwardAE(x1_tensor, x2_tensor)
                # ZINB reconstruction losses, one per modality.
                recon_loss1 = self.zinb_loss(x=x_raw1_tensor, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sf1_tensor)
                recon_loss2 = self.zinb_loss(x=x_raw2_tensor, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sf2_tensor)
                lpbatch = self.target_distribution(lqbatch)
                lqbatch = lqbatch + torch.diag(torch.diag(z_num))
                lpbatch = lpbatch + torch.diag(torch.diag(z_num))
                kl_loss = self.kldloss(lpbatch, lqbatch)
                # The KL term is enabled only for the last (1 - cutoff) fraction of epochs.
                if epoch+1 >= epochs * self.cutoff:
                    loss = recon_loss1 + recon_loss2 + kl_loss * self.phi1
                else:
                    loss = recon_loss1 + recon_loss2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_val += loss.item() * len(x1_batch)
                recon_loss1_val += recon_loss1.item() * len(x1_batch)
                recon_loss2_val += recon_loss2.item() * len(x2_batch)
                if epoch+1 >= epochs * self.cutoff:
                    kl_loss_val += kl_loss.item() * len(x1_batch)

            loss_val = loss_val/num
            recon_loss1_val = recon_loss1_val/num
            recon_loss2_val = recon_loss2_val/num
            kl_loss_val = kl_loss_val/num
            if epoch%self.t == 0:
                print('Pretrain epoch {}, Total loss:{:.6f}, ZINB loss1:{:.6f}, ZINB loss2:{:.6f}, KL loss:{:.6f}'.format(epoch+1, loss_val, recon_loss1_val, recon_loss2_val, kl_loss_val))

        if ae_save:
            torch.save({'ae_state_dict': self.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, ae_weights)
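
    # Reloading the saved weights later (sketch, matching the keys saved above):
    #     checkpoint = torch.load('AE_weights.pth.tar')
    #     model.load_state_dict(checkpoint['ae_state_dict'])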

    def save_checkpoint(self, state, index, filename):
        newfilename = os.path.join(filename, 'FTcheckpoint_%d.pth.tar' % index)
        torch.save(state, newfilename)

    def fit(self, X1, X_raw1, sf1, X2, X_raw2, sf2, y=None, lr=1., n_clusters=4,
            batch_size=256, num_epochs=10, update_interval=1, tol=1e-3, save_dir=""):
        '''X1, X2: preprocessed inputs; X_raw1, X_raw2: raw counts; sf1, sf2: size factors; y: optional labels.'''
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            self.to(self.device)
        print("Clustering stage")
        X1 = torch.tensor(X1).to(self.device)
        X_raw1 = torch.tensor(X_raw1).to(self.device)
        sf1 = torch.tensor(sf1).to(self.device)
        X2 = torch.tensor(X2).to(self.device)
        X_raw2 = torch.tensor(X_raw2).to(self.device)
        sf2 = torch.tensor(sf2).to(self.device)
        self.mu = Parameter(torch.Tensor(n_clusters, self.z_dim), requires_grad=True)
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95)

        print("Initializing cluster centers with kmeans.")
        kmeans = KMeans(n_clusters, n_init=20)
        # Run k-means on the latent representation to initialize mu.
        Zdata = self.encodeBatch(X1, X2, batch_size=batch_size)
        self.y_pred = kmeans.fit_predict(Zdata.data.cpu().numpy())
        self.y_pred_last = self.y_pred
        self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_))

        if y is not None:
            ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
            nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
            ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
            print('Initializing k-means: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (ami, nmi, ari))

        self.train()
        num = X1.shape[0]
        num_batch = int(math.ceil(1.0*X1.shape[0]/batch_size))

        final_ami, final_nmi, final_ari, final_epoch = 0, 0, 0, 0

        for epoch in range(num_epochs):
            if epoch%update_interval == 0:
                # Re-embed the full dataset and refresh the hard cluster assignments.
                Zdata = self.encodeBatch(X1, X2, batch_size=batch_size)
                dist, _ = self.kmeans_loss(Zdata)
                self.y_pred = torch.argmin(dist, dim=1).data.cpu().numpy()
                if y is not None:
                    final_ami = ami = np.round(metrics.adjusted_mutual_info_score(y, self.y_pred), 5)
                    final_nmi = nmi = np.round(metrics.normalized_mutual_info_score(y, self.y_pred), 5)
                    final_ari = ari = np.round(metrics.adjusted_rand_score(y, self.y_pred), 5)
                    final_epoch = epoch+1
                    print('Clustering %d: AMI= %.4f, NMI= %.4f, ARI= %.4f' % (epoch+1, ami, nmi, ari))

                # Stopping criterion: fraction of points whose assignment changed.
                delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num
                self.y_pred_last = self.y_pred
                if epoch>0 and delta_label < tol:
                    print('delta_label ', delta_label, '< tol ', tol)
                    print("Reach tolerance threshold. Stopping training.")
                    break

                # Optionally save the current model:
                # if (epoch>0 and delta_label < tol) or epoch%10 == 0:
                #     self.save_checkpoint({'epoch': epoch+1,
                #             'state_dict': self.state_dict(),
                #             'mu': self.mu,
                #             'y_pred': self.y_pred,
                #             'y_pred_last': self.y_pred_last,
                #             'y': y
                #             }, epoch+1, filename=save_dir)

            # Train one epoch on the joint reconstruction + clustering objective.
            loss_val = 0.0
            recon_loss1_val = 0.0
            recon_loss2_val = 0.0
            cluster_loss_val = 0.0
            kl_loss_val = 0.0
            for batch_idx in range(num_batch):
                x1_batch = X1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                x_raw1_batch = X_raw1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                sf1_batch = sf1[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                x2_batch = X2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                x_raw2_batch = X_raw2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]
                sf2_batch = sf2[batch_idx*batch_size : min((batch_idx+1)*batch_size, num)]

                inputs1 = Variable(x1_batch)
                rawinputs1 = Variable(x_raw1_batch)
                sfinputs1 = Variable(sf1_batch)
                inputs2 = Variable(x2_batch)
                rawinputs2 = Variable(x_raw2_batch)
                sfinputs2 = Variable(sf2_batch)

                zbatch, z_num, lqbatch, mean1_tensor, mean2_tensor, disp1_tensor, disp2_tensor, pi1_tensor, pi2_tensor = self.forward(inputs1, inputs2)
                _, cluster_loss = self.kmeans_loss(zbatch)
                recon_loss1 = self.zinb_loss(x=rawinputs1, mean=mean1_tensor, disp=disp1_tensor, pi=pi1_tensor, scale_factor=sfinputs1)
                recon_loss2 = self.zinb_loss(x=rawinputs2, mean=mean2_tensor, disp=disp2_tensor, pi=pi2_tensor, scale_factor=sfinputs2)
                target2 = self.target_distribution(lqbatch)
                lqbatch = lqbatch + torch.diag(torch.diag(z_num))
                target2 = target2 + torch.diag(torch.diag(z_num))
                kl_loss = self.kldloss(target2, lqbatch)
                loss = recon_loss1 + recon_loss2 + kl_loss * self.phi2 + cluster_loss * self.gamma
                optimizer.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.mu, 1)
                optimizer.step()
                cluster_loss_val += cluster_loss.data * len(inputs1)
                recon_loss1_val += recon_loss1.data * len(inputs1)
                recon_loss2_val += recon_loss2.data * len(inputs2)
                kl_loss_val += kl_loss.data * len(inputs1)
                loss_val = cluster_loss_val + recon_loss1_val + recon_loss2_val + kl_loss_val

            if epoch%self.t == 0:
                print("#Epoch %d: Total: %.6f Clustering Loss: %.6f ZINB Loss1: %.6f ZINB Loss2: %.6f KL Loss: %.6f" % (
                      epoch + 1, loss_val / num, cluster_loss_val / num, recon_loss1_val / num, recon_loss2_val / num, kl_loss_val / num))

        return self.y_pred, final_epoch
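
# Usage sketch (hypothetical shapes, cluster count and preprocessing, not part of this file):
#     model = scMultiCluster(input_dim1=X1.shape[1], input_dim2=X2.shape[1],
#                            encodeLayer=[256, 64, 32],
#                            decodeLayer1=[32, 64, 256], decodeLayer2=[32, 20])
#     model.pretrain_autoencoder(X1, X_raw1, sf1, X2, X_raw2, sf2, epochs=400)
#     y_pred, final_epoch = model.fit(X1, X_raw1, sf1, X2, X_raw2, sf2, n_clusters=10)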