Diff of /autoencoder_model.py [000000] .. [4782c6]

Switch to unified view

a b/autoencoder_model.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
# @Time    : 2021/8/7 14:01
4
# @Author  : Li Xiao
5
# @File    : autoencoder_model.py
6
import torch
7
from torch import nn
8
from matplotlib import pyplot as plt
9
10
class MMAE(nn.Module):
11
    def __init__(self, in_feas_dim, latent_dim, a=0.4, b=0.3, c=0.3):
12
        '''
13
        :param in_feas_dim: a list, input dims of omics data
14
        :param latent_dim: dim of latent layer
15
        :param a: weight of omics data type 1
16
        :param b: weight of omics data type 2
17
        :param c: weight of omics data type 3
18
        '''
19
        super(MMAE, self).__init__()
20
        self.a = a
21
        self.b = b
22
        self.c = c
23
        self.in_feas = in_feas_dim
24
        self.latent = latent_dim
25
26
        #encoders, multi channel input
27
        self.encoder_omics_1 = nn.Sequential(
28
            nn.Linear(self.in_feas[0], self.latent),
29
            nn.BatchNorm1d(self.latent),
30
            nn.Sigmoid()
31
        )
32
        self.encoder_omics_2 = nn.Sequential(
33
            nn.Linear(self.in_feas[1], self.latent),
34
            nn.BatchNorm1d(self.latent),
35
            nn.Sigmoid()
36
        )
37
        self.encoder_omics_3 = nn.Sequential(
38
            nn.Linear(self.in_feas[2], self.latent),
39
            nn.BatchNorm1d(self.latent),
40
            nn.Sigmoid()
41
        )
42
        #decoders
43
        self.decoder_omics_1 = nn.Sequential(nn.Linear(self.latent, self.in_feas[0]))
44
        self.decoder_omics_2 = nn.Sequential(nn.Linear(self.latent, self.in_feas[1]))
45
        self.decoder_omics_3 = nn.Sequential(nn.Linear(self.latent, self.in_feas[2]))
46
47
        #Variable initialization
48
        for name, param in MMAE.named_parameters(self):
49
            if 'weight' in name:
50
                torch.nn.init.normal_(param, mean=0, std=0.1)
51
            if 'bias' in name:
52
                torch.nn.init.constant_(param, val=0)
53
54
    def forward(self, omics_1, omics_2, omics_3):
55
        '''
56
        :param omics_1: omics data 1
57
        :param omics_2: omics data 2
58
        :param omics_3: omics data 3
59
        '''
60
        encoded_omics_1 = self.encoder_omics_1(omics_1)
61
        encoded_omics_2 = self.encoder_omics_2(omics_2)
62
        encoded_omics_3 = self.encoder_omics_3(omics_3)
63
        latent_data = torch.mul(encoded_omics_1, self.a) + torch.mul(encoded_omics_2, self.b) + torch.mul(encoded_omics_3, self.c)
64
        decoded_omics_1 = self.decoder_omics_1(latent_data)
65
        decoded_omics_2 = self.decoder_omics_2(latent_data)
66
        decoded_omics_3 = self.decoder_omics_3(latent_data)
67
        return latent_data, decoded_omics_1, decoded_omics_2, decoded_omics_3
68
69
    def train_MMAE(self, train_loader, learning_rate=0.001, device=torch.device('cpu'), epochs=100):
70
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
71
        loss_fn = nn.MSELoss()
72
        loss_ls = []
73
        for epoch in range(epochs):
74
            train_loss_sum = 0.0       #Record the loss of each epoch
75
            for (x,y) in train_loader:
76
                omics_1 = x[:, :self.in_feas[0]]
77
                omics_2 = x[:, self.in_feas[0]:self.in_feas[0]+self.in_feas[1]]
78
                omics_3 = x[:, self.in_feas[0]+self.in_feas[1]:self.in_feas[0]+self.in_feas[1]+self.in_feas[2]]
79
80
                omics_1 = omics_1.to(device)
81
                omics_2 = omics_2.to(device)
82
                omics_3 = omics_3.to(device)
83
84
                latent_data, decoded_omics_1, decoded_omics_2, decoded_omics_3 = self.forward(omics_1, omics_2, omics_3)
85
                loss = self.a*loss_fn(decoded_omics_1, omics_1)+ self.b*loss_fn(decoded_omics_2, omics_2) + self.c*loss_fn(decoded_omics_3, omics_3)
86
                optimizer.zero_grad()
87
                loss.backward()
88
                optimizer.step()
89
90
                train_loss_sum += loss.sum().item()
91
92
            loss_ls.append(train_loss_sum)
93
            print('epoch: %d | loss: %.4f' % (epoch + 1, train_loss_sum))
94
95
            #save the model every 10 epochs, used for feature extraction
96
            if (epoch+1) % 10 ==0:
97
                torch.save(self, 'model/AE/model_{}.pkl'.format(epoch+1))
98
99
        #draw the training loss curve
100
        plt.plot([i + 1 for i in range(epochs)], loss_ls)
101
        plt.xlabel('epochs')
102
        plt.ylabel('loss')
103
        plt.savefig('result/AE_train_loss.png')