Diff of /fusion.py [000000] .. [2095ed]

b/fusion.py
import torch
import torch.nn as nn
from utils import init_max_weights
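
# `init_max_weights` is imported from this repo's utils module and is not part of
# this diff; it is assumed here to (re-)initialize the nn.Linear weights of the
# module (e.g. with a small zero-mean normal). Treat that description as an
# assumption rather than documented behavior.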


class BilinearFusion(nn.Module):
    """Gated bilinear (Kronecker-product) fusion of two unimodal embeddings."""

    def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, dim1=32, dim2=32, scale_dim1=1, scale_dim2=1, mmhid=64, dropout_rate=0.25):
        super(BilinearFusion, self).__init__()
        self.skip = skip
        self.use_bilinear = use_bilinear
        self.gate1 = gate1
        self.gate2 = gate2

        dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2
        skip_dim = dim1+dim2+2 if skip else 0

        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
        self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1))
        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))

        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
        self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2))
        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))

        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
        self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
        init_max_weights(self)

    def forward(self, vec1, vec2):
        ### Gated Multimodal Units
        if self.gate1:
            h1 = self.linear_h1(vec1)
            z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1))
            o1 = self.linear_o1(torch.sigmoid(z1)*h1)
        else:
            o1 = self.linear_o1(vec1)

        if self.gate2:
            h2 = self.linear_h2(vec2)
            z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1))
            o2 = self.linear_o2(torch.sigmoid(z2)*h2)
        else:
            o2 = self.linear_o2(vec2)

        ### Fusion
        # Append a constant 1 to each gated representation so the outer product
        # also carries the unimodal terms, then flatten the batched outer product.
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE x (dim1+1)*(dim2+1)
        out = self.post_fusion_dropout(o12)
        out = self.encoder1(out)
        if self.skip: out = torch.cat((out, o1, o2), 1)
        out = self.encoder2(out)
        return out
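
# Fusion detail: for one sample, o12 above is the flattened outer product of
# [o1; 1] and [o2; 1], so it contains every pairwise product of the two gated
# embeddings plus each embedding's own entries (via the appended 1). An
# equivalent per-sample formulation (illustrative only; `o1_row`/`o2_row` are
# hypothetical 1-D rows of the gated embeddings) uses torch.kron:
#
#   a = torch.cat((o1_row, torch.ones(1)))
#   b = torch.cat((o2_row, torch.ones(1)))
#   torch.kron(a, b)  # == torch.bmm(a[None, :, None], b[None, None, :]).flatten()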


class TrilinearFusion_A(nn.Module):
    """Gated trilinear fusion of path, graph, and omic embeddings (graph gated with omic)."""

    def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
        super(TrilinearFusion_A, self).__init__()
        self.skip = skip
        self.use_bilinear = use_bilinear
        self.gate1 = gate1
        self.gate2 = gate2
        self.gate3 = gate3

        dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1//scale_dim1, dim2//scale_dim2, dim3//scale_dim3
        skip_dim = dim1+dim2+dim3+3 if skip else 0

        ### Path
        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
        self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim1))
        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))

        ### Graph
        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
        self.linear_z2 = nn.Bilinear(dim2_og, dim3_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim2_og+dim3_og, dim2))
        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))

        ### Omic
        self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
        self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim3))
        self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate))

        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1)*(dim3+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
        self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))

        init_max_weights(self)

    def forward(self, vec1, vec2, vec3):
        ### Gated Multimodal Units
        if self.gate1:
            h1 = self.linear_h1(vec1)
            z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic
            o1 = self.linear_o1(torch.sigmoid(z1)*h1)
        else:
            o1 = self.linear_o1(vec1)

        if self.gate2:
            h2 = self.linear_h2(vec2)
            z2 = self.linear_z2(vec2, vec3) if self.use_bilinear else self.linear_z2(torch.cat((vec2, vec3), dim=1)) # Gate Graph with Omic
            o2 = self.linear_o2(torch.sigmoid(z2)*h2)
        else:
            o2 = self.linear_o2(vec2)

        if self.gate3:
            h3 = self.linear_h3(vec3)
            z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3(torch.cat((vec1, vec3), dim=1)) # Gate Omic with Path
            o3 = self.linear_o3(torch.sigmoid(z3)*h3)
        else:
            o3 = self.linear_o3(vec3)

        ### Fusion
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
        o3 = torch.cat((o3, torch.ones(o3.shape[0], 1, device=o3.device, dtype=o3.dtype)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1)
        o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1)
        out = self.post_fusion_dropout(o123)
        out = self.encoder1(out)
        if self.skip: out = torch.cat((out, o1, o2, o3), 1)
        out = self.encoder2(out)
        return out
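
# Note: TrilinearFusion_A and TrilinearFusion_B differ only in how the graph
# branch is gated. In _A above, linear_z2 conditions the graph embedding on the
# omic input (vec2, vec3); in _B below, it conditions on the path input
# (vec2, vec1). The path and omic branches are gated identically in both.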


class TrilinearFusion_B(nn.Module):
    """Gated trilinear fusion of path, graph, and omic embeddings (graph gated with path)."""

    def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
        super(TrilinearFusion_B, self).__init__()
        self.skip = skip
        self.use_bilinear = use_bilinear
        self.gate1 = gate1
        self.gate2 = gate2
        self.gate3 = gate3

        dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1//scale_dim1, dim2//scale_dim2, dim3//scale_dim3
        skip_dim = dim1+dim2+dim3+3 if skip else 0

        ### Path
        self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
        self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim1))
        self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate))

        ### Graph
        self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
        self.linear_z2 = nn.Bilinear(dim2_og, dim1_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim2_og+dim1_og, dim2))
        self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate))

        ### Omic
        self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
        self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim3_og, dim3))
        self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate))

        self.post_fusion_dropout = nn.Dropout(p=dropout_rate)
        self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1)*(dim3+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))
        self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate))

        init_max_weights(self)

    def forward(self, vec1, vec2, vec3):
        ### Gated Multimodal Units
        if self.gate1:
            h1 = self.linear_h1(vec1)
            z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic
            o1 = self.linear_o1(torch.sigmoid(z1)*h1)
        else:
            o1 = self.linear_o1(vec1)

        if self.gate2:
            h2 = self.linear_h2(vec2)
            z2 = self.linear_z2(vec2, vec1) if self.use_bilinear else self.linear_z2(torch.cat((vec2, vec1), dim=1)) # Gate Graph with Path
            o2 = self.linear_o2(torch.sigmoid(z2)*h2)
        else:
            o2 = self.linear_o2(vec2)

        if self.gate3:
            h3 = self.linear_h3(vec3)
            z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3(torch.cat((vec1, vec3), dim=1)) # Gate Omic with Path
            o3 = self.linear_o3(torch.sigmoid(z3)*h3)
        else:
            o3 = self.linear_o3(vec3)

        ### Fusion
        o1 = torch.cat((o1, torch.ones(o1.shape[0], 1, device=o1.device, dtype=o1.dtype)), 1)
        o2 = torch.cat((o2, torch.ones(o2.shape[0], 1, device=o2.device, dtype=o2.dtype)), 1)
        o3 = torch.cat((o3, torch.ones(o3.shape[0], 1, device=o3.device, dtype=o3.dtype)), 1)
        o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1)
        o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1)
        out = self.post_fusion_dropout(o123)
        out = self.encoder1(out)
        if self.skip: out = torch.cat((out, o1, o2, o3), 1)
        out = self.encoder2(out)
        return out
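

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module):
    # instantiate the fusion blocks on random inputs and check output shapes.
    # The batch size and 32-dimensional features are arbitrary choices here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vec1 = torch.randn(4, 32, device=device)
    vec2 = torch.randn(4, 32, device=device)
    vec3 = torch.randn(4, 32, device=device)

    bilinear = BilinearFusion(dim1=32, dim2=32, mmhid=64).to(device)
    print(bilinear(vec1, vec2).shape)        # expected: torch.Size([4, 64])

    trilinear = TrilinearFusion_B(dim1=32, dim2=32, dim3=32, mmhid=96).to(device)
    print(trilinear(vec1, vec2, vec3).shape)  # expected: torch.Size([4, 96])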