# autoencoder/autoencoder.py
import sys
import os
import collections.abc

import numpy as np

# Add the local library path (Windows path first, then the Linux fallback) to sys.path.
lib_path = 'I:/code'
if not os.path.exists(lib_path):
    lib_path = '/media/6T/.tianle/.lib'
if os.path.exists(lib_path) and lib_path not in sys.path:
    sys.path.append(lib_path)

import torch
import torch.nn as nn

from dl.models.basic_models import DenseLinear, get_list, get_attr
from dl.utils.train import cosine_similarity


class AutoEncoder(nn.Module):
    r"""Factorization autoencoder.

    The input is encoded by a DenseLinear encoder; the hidden representation is then fed to
    two linear heads: a classifier and a decoder that reconstructs the input.

    Args:
        in_dim: dimension of the input features
        hidden_dims: list of hidden-layer sizes for the encoder; hidden_dims[-1] is the code dimension
        num_classes: output dimension of the classification head
        dense, residual, residual_layers, nonlinearity, last_nonlinearity, bias: passed to DenseLinear
        decoder_norm: if True, wrap the decoder (an nn.Linear) with torch.nn.utils.weight_norm
        decoder_norm_dim: default 0; dim argument passed to torch.nn.utils.weight_norm
        uniform_decoder_norm: if True, fix the decoder weight norm to 1 (weight_g is frozen)

    Shape:
        Input: (N, in_dim)
        Output: class scores of shape (N, num_classes) and a reconstruction of shape (N, in_dim)

    Attributes:
        encoder: a DenseLinear module
        decoder: an nn.Linear module (optionally weight-normalized)
        classifier: an nn.Linear module producing the class scores
    """
    def __init__(self, in_dim, hidden_dims, num_classes, dense=True, residual=False, residual_layers='all',
                 decoder_norm=False, decoder_norm_dim=0, uniform_decoder_norm=False, nonlinearity=nn.ReLU(),
                 last_nonlinearity=True, bias=True):
        super(AutoEncoder, self).__init__()
        self.encoder = DenseLinear(in_dim, hidden_dims, nonlinearity=nonlinearity, last_nonlinearity=last_nonlinearity,
                                   dense=dense, residual=residual, residual_layers=residual_layers, forward_input=False,
                                   return_all=False, return_layers=None, bias=bias)
        self.decoder_norm = decoder_norm
        self.uniform_decoder_norm = uniform_decoder_norm
        if self.decoder_norm:
            self.decoder = nn.utils.weight_norm(nn.Linear(hidden_dims[-1], in_dim), 'weight', dim=decoder_norm_dim)
            if self.uniform_decoder_norm:
                # This changes the shape of weight_g, but that's OK (weight_norm broadcasts it)
                self.decoder.weight_g.data = self.decoder.weight_g.new_ones(1)
                self.decoder.weight_g.requires_grad_(False)
        else:
            self.decoder = nn.Linear(hidden_dims[-1], in_dim)
        self.classifier = nn.Linear(hidden_dims[-1], num_classes)

    def forward(self, x):
        out = self.encoder(x)
        return self.classifier(out), self.decoder(out)

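# Illustrative usage sketch (not in the original file); it assumes DenseLinear accepts the
# arguments passed above and maps (N, in_dim) -> (N, hidden_dims[-1]).
#
# >>> model = AutoEncoder(in_dim=100, hidden_dims=[64, 32], num_classes=5)
# >>> x = torch.randn(8, 100)
# >>> scores, recon = model(x)
# >>> scores.shape, recon.shape
# (torch.Size([8, 5]), torch.Size([8, 100]))

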
class MultiviewAE(nn.Module):
    r"""Multiview autoencoder.

    Args:
        in_dims: a list (or iterable) of integers, one input dimension per view
        hidden_dims: a list of ints if every view has the same hidden dims; otherwise a list of lists of ints
        out_dim: for classification, out_dim = num_classes
        fuse_type: default 'sum', average the outputs of all encoders (this requires all encoder outputs
            to have the same dimension); if 'cat', concatenate the outputs of all encoders
        dense, residual, residual_layers, nonlinearity, last_nonlinearity, bias: passed to DenseLinear
        decoder_norm: if True, apply torch.nn.utils.weight_norm (a forward pre-hook) to each decoder (an nn.Linear)
        decoder_norm_dim: default 0; passed to torch.nn.utils.weight_norm
        uniform_decoder_norm: if True, fix the decoder weight norm to 1 along dim=decoder_norm_dim

    Shape:
        Input: a list of tensors, or a single tensor that will be split into a list along dim 1
        Output: class score matrix of shape (N, out_dim), concatenated decoder output of shape (N, sum(in_dims)),
            and concatenated encoder output

    Attributes:
        A ModuleList of DenseLinear encoders and a ModuleList of nn.Linear decoders
        An nn.Linear output layer (e.g., producing the class score matrix)

    Examples:
        >>> x = torch.randn(10, 5)
        >>> model = MultiviewAE([2, 3], [5, 5], 7)
        >>> y = model(x)
        >>> y[0].shape, y[1].shape
        (torch.Size([10, 7]), torch.Size([10, 5]))

    """
    def __init__(self, in_dims, hidden_dims, out_dim, fuse_type='sum', dense=False, residual=True,
                 residual_layers='all', decoder_norm=False, decoder_norm_dim=0, uniform_decoder_norm=False,
                 nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True):
        super(MultiviewAE, self).__init__()
        self.num_views = len(in_dims)
        self.in_dims = in_dims
        self.out_dim = out_dim
        self.fuse_type = fuse_type
        if not isinstance(hidden_dims[0], collections.abc.Iterable):
            # hidden_dims is a list of ints, which means all views have the same hidden dims
            hidden_dims = [hidden_dims] * self.num_views
        self.hidden_dims = hidden_dims
        assert len(self.hidden_dims) == self.num_views and isinstance(self.hidden_dims[0], collections.abc.Iterable)
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for in_dim, hidden_dim in zip(in_dims, hidden_dims):
            self.encoders.append(DenseLinear(in_dim, hidden_dim, nonlinearity=nonlinearity,
                last_nonlinearity=last_nonlinearity, dense=dense, forward_input=False, return_all=False,
                return_layers=None, bias=bias, residual=residual, residual_layers=residual_layers))
            decoder = nn.Linear(hidden_dim[-1], in_dim)
            if decoder_norm:
                torch.nn.utils.weight_norm(decoder, 'weight', dim=decoder_norm_dim)
                if uniform_decoder_norm:
                    decoder.weight_g.data = decoder.weight_g.new_ones(decoder.weight_g.size())
                    decoder.weight_g.requires_grad_(False)
            self.decoders.append(decoder)
        self.fuse_dims = [hidden_dim[-1] for hidden_dim in self.hidden_dims]
        if self.fuse_type == 'sum':
            fuse_dim = self.fuse_dims[0]
            for d in self.fuse_dims:
                assert d == fuse_dim
        elif self.fuse_type == 'cat':
            fuse_dim = sum(self.fuse_dims)
        else:
            raise ValueError(f"fuse_type should be 'sum' or 'cat', but is {fuse_type}")
        self.output = nn.Linear(fuse_dim, out_dim)

    def forward(self, xs):
        if isinstance(xs, torch.Tensor):
            xs = xs.split(self.in_dims, dim=1)
        # assert len(xs) == self.num_views
        encoder_out = []
        decoder_out = []
        for i, x in enumerate(xs):
            out = self.encoders[i](x)
            encoder_out.append(out)
            decoder_out.append(self.decoders[i](out))
        if self.fuse_type == 'sum':
            # Note: despite the name 'sum', the encoder outputs are averaged here
            out = torch.stack(encoder_out, dim=-1).mean(dim=-1)
        else:
            out = torch.cat(encoder_out, dim=-1)
        out = self.output(out)
        return out, torch.cat(decoder_out, dim=-1), torch.cat(encoder_out, dim=-1)

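# Illustrative sketch of the 'cat' fuse type (not in the original file); it assumes DenseLinear
# accepts a per-view hidden_dims list such as [4] and [6].  With fuse_type='cat' the encoder
# outputs are concatenated before the output layer, and the third return value is that concatenation.
#
# >>> model = MultiviewAE([2, 3], [[4], [6]], 7, fuse_type='cat')
# >>> scores, recon, codes = model(torch.randn(10, 5))
# >>> scores.shape, recon.shape, codes.shape
# (torch.Size([10, 7]), torch.Size([10, 5]), torch.Size([10, 10]))

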
def get_interaction_loss(interaction_mat, w, loss_type='graph_laplacian', normalize=True):
    """Compute a loss penalizing the inconsistency between feature representations w (N, D)
    and a feature interaction network interaction_mat (N, N).
    A trivial solution is that all features (row vectors of w) have cosine similarity 1 or distance 0.

    Args:
        interaction_mat: non-negative symmetric torch.Tensor of shape (N, N)
        w: feature representation tensor of shape (N, D)
        normalize: if True, for loss_type 'graph_laplacian' or 'dot_product',
            set w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0));
            this makes sure w.norm() = 1 and all row vectors of w have the same norm
            (len(torch.unique(w.norm(dim=1))) == 1);
            for loss_type 'cosine_similarity', set loss = loss / w.size(0).
            This factors out the number of features, which is useful for combining losses from multiple views.

    See Loss_feature_interaction for more documentation.
    """
    if loss_type == 'cosine_similarity':
        # -(|cos(w, w)| * interaction_mat).sum()
        cos = cosine_similarity(w).abs()  # absolute value of the cosine similarity matrix
        loss = -(cos * interaction_mat).sum()
        if normalize:
            loss = loss / w.size(0)
    elif loss_type == 'graph_laplacian':
        # trace(w' * L * w)
        if normalize:
            w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
            interaction_mat = interaction_mat / interaction_mat.norm()  # ensure interaction_mat is normalized
        diag = torch.diag(interaction_mat.sum(dim=1))
        L_interaction_mat = diag - interaction_mat
        loss = torch.diagonal(torch.mm(torch.mm(w.t(), L_interaction_mat), w)).sum()
    elif loss_type == 'dot_product':
        # pairwise distance matrix * interaction matrix
        if normalize:
            w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
        d = torch.sum(w * w, dim=1)  # if normalize is True, every element of d equals 1 / w.size(0)
        dist = d.unsqueeze(1) + d - 2 * torch.mm(w, w.t())  # squared pairwise distances
        loss = (dist * interaction_mat).sum()
        # loss = (dist / dist.norm() * interaction_mat).sum()  # an alternative to 'normalize'
    else:
        raise ValueError(f"loss_type can only be 'cosine_similarity', "
                         f"'graph_laplacian' or 'dot_product', but is {loss_type}")
    return loss

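# Illustrative usage sketch (not in the original file): a random symmetric non-negative
# interaction matrix over N = 20 features and a (20, 8) representation matrix.
#
# >>> a = torch.rand(20, 20)
# >>> interaction_mat = (a + a.t()) / 2          # non-negative and symmetric
# >>> w = torch.randn(20, 8, requires_grad=True)
# >>> loss = get_interaction_loss(interaction_mat, w, loss_type='graph_laplacian')
# >>> loss.backward()                            # populates w.grad

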
class Loss_feature_interaction(nn.Module):
    r"""A customized loss imposing a graph Laplacian constraint derived from the feature interaction network.

    For a factorization autoencoder, the decoder weights can be seen as feature representations;
    this loss measures the inconsistency between the learned feature representations and their interaction network.
    A trivial solution is that all features have cosine similarity 1 or distance 0.

    Args:
        interaction_mat: torch.Tensor of shape (N, N), a non-negative (symmetric) matrix,
            or a list of such matrices, each an interaction matrix.
            To control the magnitude of the loss, it is preferred to have interaction_mat.norm() = 1.
        loss_type: if 'cosine_similarity', calculate -(cos(m, m).abs() * interaction_mat).sum();
            if 'graph_laplacian' (faster), calculate trace(m' * L * m);
            if 'dot_product', calculate dist(m) * interaction_mat,
            where dist(m) is the pairwise distance matrix of the features (the name 'dot_product' is misleading).
            If all features have norm 1, the three types are equivalent in a sense;
            'cosine_similarity' is preferred because the magnitude of the features is implicitly ignored,
            while the other two are affected by the feature magnitudes.
        weight_path: default ['decoder', 'weight'], so that w = model.decoder.weight
        normalize: passed to get_interaction_loss;
            if True, set w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
            for loss_type 'graph_laplacian' or 'dot_product'
            (this makes sure each row vector of w has the same norm and w.norm() = 1),
            and set loss = loss / w.size(0) for loss_type 'cosine_similarity'.
            This factors out the number of features, which is useful for combining losses from multiple views.

    Inputs:
        model: the AutoEncoder model defined above or another model,
            or a given weight matrix w;
            if interaction_mat has shape (N, N), then w has shape (N, D)

    Returns:
        loss: a torch.Tensor on which loss.backward() can be called
    """
    def __init__(self, interaction_mat, loss_type='graph_laplacian', weight_path=['decoder', 'weight'],
                 normalize=True):
        super(Loss_feature_interaction, self).__init__()
        self.loss_type = loss_type
        self.weight_path = weight_path
        self.normalize = normalize
        # If interaction_mat is a list, self.sections will be used for splitting the weight matrix
        self.sections = None  # when interaction_mat is a single matrix, self.sections is None
        if isinstance(interaction_mat, (list, tuple)):
            if normalize:  # ensure each interaction_mat is normalized
                interaction_mat = [m / m.norm() for m in interaction_mat]
            self.sections = [m.shape[0] for m in interaction_mat]
        else:
            if normalize:  # ensure interaction_mat is normalized
                interaction_mat = interaction_mat / interaction_mat.norm()
        if self.loss_type == 'graph_laplacian':
            # Precalculate self.L_interaction_mat to save some computation in each forward pass
            if self.sections is None:
                diag = torch.diag(interaction_mat.sum(dim=1))
                self.L_interaction_mat = diag - interaction_mat  # graph Laplacian; should it be normalized?
            else:
                self.L_interaction_mat = []
                for mat in interaction_mat:
                    diag = torch.diag(mat.sum(dim=1))
                    self.L_interaction_mat.append(diag - mat)
        else:  # we only need to store interaction_mat when loss_type != 'graph_laplacian'
            self.interaction_mat = interaction_mat

    def forward(self, model=None, w=None):
        if w is None:
            w = get_attr(model, self.weight_path)
        if self.sections is None:
            # There is only one interaction matrix; self.interaction_mat is a torch.Tensor
            if self.loss_type == 'graph_laplacian':
                # Use the precalculated L_interaction_mat to save some time
                if self.normalize:
                    # interaction_mat has already been normalized during initialization
                    w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
                return torch.diagonal(torch.mm(torch.mm(w.t(), self.L_interaction_mat), w)).sum()
            else:
                return get_interaction_loss(self.interaction_mat, w, loss_type=self.loss_type, normalize=self.normalize)
        else:
            # self.interaction_mat is a list of torch.Tensors
            if isinstance(w, torch.Tensor):
                w = w.split(self.sections, dim=0)
            if self.loss_type == 'graph_laplacian':  # handle 'graph_laplacian' separately to save time during training
                loss = 0
                for w_, L in zip(w, self.L_interaction_mat):
                    if self.normalize:  # make sure w_.norm() = 1 and each row vector of w_ has the same norm
                        w_ = w_ / w_.norm(p=2, dim=1, keepdim=True) / np.sqrt(w_.size(0))
                    loss += torch.diagonal(torch.mm(torch.mm(w_.t(), L), w_)).sum()
                return loss
            # for the cases 'cosine_similarity' and 'dot_product'
            return sum([get_interaction_loss(mat, w_, loss_type=self.loss_type, normalize=self.normalize)
                        for mat, w_ in zip(self.interaction_mat, w)])

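# Illustrative usage sketch (not in the original file): pass the decoder weight matrix directly as w
# instead of a model, using the 'graph_laplacian' loss.  Shapes follow the docstring: interaction_mat
# is (N, N) and w is (N, D), with N input features.
#
# >>> a = torch.rand(50, 50)
# >>> interaction_mat = (a + a.t()) / 2
# >>> criterion = Loss_feature_interaction(interaction_mat, loss_type='graph_laplacian')
# >>> w = torch.randn(50, 16, requires_grad=True)   # e.g. model.decoder.weight
# >>> loss = criterion(w=w)
# >>> loss.backward()

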
class Loss_view_similarity(nn.Module):
    r"""Loss on the inconsistency among multiple views of the same set of samples.

    The input is a multi-view representation of the same set of patients,
    i.e., a set of matrices of shape (num_samples, feature_dim); feature_dim can differ across views.
    This loss penalizes the inconsistency among different views.
    This is somewhat limited, because different views should carry both shared and complementary information;
    this loss only encourages the shared information across views,
    which may or may not be good for certain applications.
    A trivial solution is that the multi-view representations are all the same; then loss -> -1.
    The two loss_types 'circle' and 'hub' can behave quite differently and can be unstable.
    'circle' tries to make all feature representations across views have high cosine similarity,
    while 'hub' only tries to make feature representations within each view have high cosine similarity;
    by multiplying the 'mean-feature' target with the 'hub' loss_type, it might capture both within-view and
    cross-view similarity, so that is the default choice; however, my limited experimental results do not
    validate this; instead, the choice between 'circle' and 'hub' is dominant, while explicit_target and
    cal_target do not make a big difference.
    Cosine similarity is used here; to do: other similarity metrics.

    Args:
        sections: a list of integers (or an int); used to split the input matrix into chunks,
            each chunk corresponding to one view representation.
            If the input xs is not a torch.Tensor, this is not used; xs is then assumed to be a list of torch.Tensors.
            If sections is an int, all feature dims are the same; set sections = feature_dim, NOT the number of sections!
        loss_type: suppose there are three views x1, x2, x3; let s_ij = cos(x_i, x_j) and s_i = cos(x_i, x_i).
            If loss_type == 'circle': similarity = s12*s23*target if fusion_type == 'multiply',
            or s12 + s23 if fusion_type == 'sum'.
            This is fastest but requires x1, x2, x3 to have the same shape.
            If loss_type == 'hub': similarity = s1*s2*s3*target if fusion_type == 'multiply',
            or |s1| + |s2| + |s3| + |target| if fusion_type == 'sum'.
            Implicitly, target = 1 (fusion_type == 'multiply') or 0 (fusion_type == 'sum') if explicit_target is False.
            If graph_laplacian is False:
                loss = -similarity.abs().mean()
            else:
                s = similarity.abs(); L_s = torch.diag(s.sum(dim=1)) - s  # graph Laplacian
                loss = sum_i(x_i * L_s * x_i^T)
        explicit_target: if False, target = 1 (fusion_type == 'multiply') or 0 (fusion_type == 'sum') implicitly;
            if True, use the given target or calculate it from xs.
            # to do: handle the case when we only use the explicitly given target
        cal_target: if 'mean-similarity', target = (cos(x1,x1) + cos(x2,x2) + cos(x3,x3)) / 3;
            if 'mean-feature', x = (x1+x2+x3)/3 and target = cos(x, x); this requires x1, x2, x3 to have the same shape.
        target: default None; only used when explicit_target is True.
            This saves computation if the target is provided in advance or passed as an input.
        fusion_type: if 'multiply', similarity = product(similarities); if 'sum', similarity = sum(|similarities|);
            works together with loss_type.
        graph_laplacian: if False:
                loss = -similarity.abs().mean()
            else:
                s = similarity.abs(); L_s = torch.diag(s.sum(dim=1)) - s  # graph Laplacian
                loss = sum_i(x_i * L_s * x_i^T)

    Inputs:
        xs: a set of torch.Tensor matrices of shape (num_samples, feature_dim),
            or a single matrix with self.sections specified
        target: the target cosine similarity matrix; default None;
            if not given, first check whether self.target is given;
            if self.target is None, calculate it according to cal_target;
            only used when self.explicit_target is True

    Output:
        loss = -similarity.abs().mean() if graph_laplacian is False  # is this the right way to do it?
             = sum_i(x_i * L_s * x_i^T) if graph_laplacian is True   # calls get_interaction_loss()

    """
    def __init__(self, sections=None, loss_type='hub', explicit_target=False,
                 cal_target='mean-feature', target=None, fusion_type='multiply', graph_laplacian=False):
        super(Loss_view_similarity, self).__init__()
        self.sections = sections
        if self.sections is not None:
            if not isinstance(self.sections, int):
                assert len(self.sections) >= 2
        self.loss_type = loss_type
        assert self.loss_type in ['circle', 'hub']
        self.explicit_target = explicit_target
        self.cal_target = cal_target
        self.target = target
        self.fusion_type = fusion_type
        self.graph_laplacian = graph_laplacian
        # I easily got nan losses whenever graph_laplacian was True, especially in the cases below;
        # I do not know why.  Probably the similarity needs to be normalized in every forward pass?
        assert not (fusion_type == 'multiply' and graph_laplacian) and not (loss_type == 'circle' and graph_laplacian)

    def forward(self, xs, target=None):
        if isinstance(xs, torch.Tensor):
            # make sure xs is a list of tensors corresponding to multiple views;
            # this requires self.sections to be valid
            xs = xs.split(self.sections, dim=1)
        # assert len(xs) >= 2  # commented out to save time over many forward passes
        similarity = 1
        if self.loss_type == 'circle':
            # assert xs[i-1].shape == xs[i].shape
            # only adjacent views are compared, which saves computation
            similarity_mats = [cosine_similarity(xs[i - 1], xs[i]) for i in range(1, len(xs))]
            similarity_mats = [(m + m.t()) / 2 for m in similarity_mats]  # make the similarity matrices symmetric
        elif self.loss_type == 'hub':
            similarity_mats = [cosine_similarity(x) for x in xs]
        if self.fusion_type == 'multiply':
            for m in similarity_mats:
                similarity = similarity * m  # elementwise multiplication keeps the largest possible value at 1
        elif self.fusion_type == 'sum':
            similarity = sum(similarity_mats) / len(similarity_mats)  # take the mean so that the largest value is 1

        if self.explicit_target:
            if target is None:
                if self.target is None:
                    if self.cal_target == 'mean-similarity':
                        target = torch.stack(similarity_mats, dim=0).mean(0)
                    elif self.cal_target == 'mean-feature':
                        x = torch.stack(xs, -1).mean(-1)  # the view matrices must have the same shape
                        target = cosine_similarity(x)
                    else:
                        raise ValueError(f'cal_target should be mean-similarity or mean-feature, but is {self.cal_target}')
                else:
                    target = self.target
            if self.fusion_type == 'multiply':
                similarity = similarity * target
            elif self.fusion_type == 'sum':
                similarity = (len(similarity_mats) * similarity + target) / (len(similarity_mats) + 1)  # moving average
        similarity = similarity.abs()  # ensure the similarity is non-negative
        if self.graph_laplacian:
            # Easily produces nan losses when this is True; I do not know why
            return sum([get_interaction_loss(similarity, w, loss_type='graph_laplacian', normalize=True) for w in xs]) / len(xs)
        else:
            return -similarity.mean()  # ensures the loss is within the range [-1, 0]
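

# Minimal smoke test added for illustration (not in the original file).  It only exercises the
# pieces that need torch/numpy at runtime; importing this module still requires the
# dl.models.basic_models and dl.utils.train packages imported at the top.
if __name__ == '__main__':
    torch.manual_seed(0)
    a = torch.rand(30, 30)
    interaction_mat = (a + a.t()) / 2           # non-negative, symmetric (30, 30)
    w = torch.randn(30, 8, requires_grad=True)  # 30 features, 8-dim representation

    # functional form
    loss = get_interaction_loss(interaction_mat, w, loss_type='graph_laplacian', normalize=True)
    loss.backward()
    print('get_interaction_loss (graph_laplacian):', loss.item())

    # module form, passing the weight matrix directly instead of a model
    criterion = Loss_feature_interaction(interaction_mat, loss_type='graph_laplacian')
    print('Loss_feature_interaction:', criterion(w=w).item())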