
autoencoder/autoencoder.py
import sys
import os
import collections.abc

import numpy as np

lib_path = 'I:/code'
if not os.path.exists(lib_path):
  lib_path = '/media/6T/.tianle/.lib'
if os.path.exists(lib_path) and lib_path not in sys.path:
  sys.path.append(lib_path)

import torch
import torch.nn as nn

from dl.models.basic_models import DenseLinear, get_list, get_attr
from dl.utils.train import cosine_similarity

class AutoEncoder(nn.Module):
  r"""Factorization autoencoder with a classification head and a reconstruction head.

  Args:
    in_dim: input feature dimension
    hidden_dims: a list of ints; the hidden dims of the DenseLinear encoder
    num_classes: output dimension of the classification head
    dense, residual, residual_layers, nonlinearity, last_nonlinearity, bias: passed to DenseLinear
    decoder_norm: if True, wrap the decoder (an nn.Linear module) with torch.nn.utils.weight_norm
    decoder_norm_dim: default 0; passed to torch.nn.utils.weight_norm
    uniform_decoder_norm: if True, fix the decoder weight norm to 1 along dim=decoder_norm_dim

  Shape:
    Input: (N, in_dim)
    Output: class scores of shape (N, num_classes) and reconstruction of shape (N, in_dim)

  Attributes:
    encoder: a DenseLinear module
    decoder: an nn.Linear module, optionally weight-normalized
    classifier: an nn.Linear module

  Examples::
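
    A minimal sketch (shapes assume DenseLinear maps in_dim to hidden_dims[-1]; values illustrative):

    >>> model = AutoEncoder(in_dim=100, hidden_dims=[64, 32], num_classes=5)
    >>> x = torch.randn(8, 100)
    >>> scores, recon = model(x)
    >>> scores.shape, recon.shape
    (torch.Size([8, 5]), torch.Size([8, 100]))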
  """
  def __init__(self, in_dim, hidden_dims, num_classes, dense=True, residual=False, residual_layers='all',
    decoder_norm=False, decoder_norm_dim=0, uniform_decoder_norm=False, nonlinearity=nn.ReLU(),
    last_nonlinearity=True, bias=True):
    super(AutoEncoder, self).__init__()
    self.encoder = DenseLinear(in_dim, hidden_dims, nonlinearity=nonlinearity, last_nonlinearity=last_nonlinearity,
      dense=dense, residual=residual, residual_layers=residual_layers, forward_input=False, return_all=False,
      return_layers=None, bias=bias)
    self.decoder_norm = decoder_norm
    self.uniform_decoder_norm = uniform_decoder_norm
    if self.decoder_norm:
      self.decoder = nn.utils.weight_norm(nn.Linear(hidden_dims[-1], in_dim), 'weight', dim=decoder_norm_dim)
      if self.uniform_decoder_norm:
        # Fix all weight norms to 1; new_ones(1) changes weight_g's shape, but broadcasting keeps this valid
        self.decoder.weight_g.data = self.decoder.weight_g.new_ones(1)
        self.decoder.weight_g.requires_grad_(False)
    else:
      self.decoder = nn.Linear(hidden_dims[-1], in_dim)
    self.classifier = nn.Linear(hidden_dims[-1], num_classes)

  def forward(self, x):
    out = self.encoder(x)
    # Two heads: class scores and input reconstruction
    return self.classifier(out), self.decoder(out)


class MultiviewAE(nn.Module):
  r"""Multiview autoencoder.

  Args:
    in_dims: a list (or iterable) of integers, one input dimension per view
    hidden_dims: a list of ints if every view has the same hidden dims; otherwise a list of lists of ints
    out_dim: for classification, out_dim = num_cls
    fuse_type: default 'sum', fuse the encoder outputs by averaging them (implemented with a mean);
      requires all encoder outputs to have the same dimension;
      if 'cat', concatenate the outputs of all encoders
    dense, residual, residual_layers, nonlinearity, last_nonlinearity, bias are passed to DenseLinear
    decoder_norm: if True, apply torch.nn.utils.weight_norm (a forward pre-hook) to each decoder (an nn.Linear module)
    decoder_norm_dim: default 0; passed to torch.nn.utils.weight_norm
    uniform_decoder_norm: if True, fix the decoder weight norm to 1 along dim=decoder_norm_dim

  Shape:
    Input: a list of tensors, or a single tensor that will be split into a list along dim 1
    Output: three tensors: score matrix of shape (N, out_dim), concatenated decoder outputs of shape
      (N, sum(in_dims)), and concatenated encoder outputs of shape (N, sum of the last hidden dims)

  Attributes:
    A list of DenseLinear modules as encoders and a list of nn.Linear modules as decoders
    An nn.Linear as output layer (e.g., class score matrix)

  Examples:
    >>> x = torch.randn(10, 5)
    >>> model = MultiviewAE([2,3], [5, 5], 7)
    >>> y = model(x)
    >>> y[0].shape, y[1].shape

  """
  def __init__(self, in_dims, hidden_dims, out_dim, fuse_type='sum', dense=False, residual=True,
    residual_layers='all', decoder_norm=False, decoder_norm_dim=0, uniform_decoder_norm=False,
    nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True):
    super(MultiviewAE, self).__init__()
    self.num_views = len(in_dims)
    self.in_dims = in_dims
    self.out_dim = out_dim
    self.fuse_type = fuse_type
    if not isinstance(hidden_dims[0], collections.abc.Iterable):
      # hidden_dims is a list of ints, which means all views have the same hidden dims
      hidden_dims = [hidden_dims] * self.num_views
    self.hidden_dims = hidden_dims
    assert len(self.hidden_dims) == self.num_views and isinstance(self.hidden_dims[0], collections.abc.Iterable)
    self.encoders = nn.ModuleList()
    self.decoders = nn.ModuleList()
    for in_dim, hidden_dim in zip(in_dims, hidden_dims):
      self.encoders.append(DenseLinear(in_dim, hidden_dim, nonlinearity=nonlinearity,
        last_nonlinearity=last_nonlinearity, dense=dense, forward_input=False, return_all=False,
        return_layers=None, bias=bias, residual=residual, residual_layers=residual_layers))
      decoder = nn.Linear(hidden_dim[-1], in_dim)
      if decoder_norm:
        torch.nn.utils.weight_norm(decoder, 'weight', dim=decoder_norm_dim)
        if uniform_decoder_norm:
          decoder.weight_g.data = decoder.weight_g.new_ones(decoder.weight_g.size())
          decoder.weight_g.requires_grad_(False)
      self.decoders.append(decoder)
    self.fuse_dims = [hidden_dim[-1] for hidden_dim in self.hidden_dims]
    if self.fuse_type == 'sum':
      fuse_dim = self.fuse_dims[0]
      for d in self.fuse_dims:
        assert d == fuse_dim
    elif self.fuse_type == 'cat':
      fuse_dim = sum(self.fuse_dims)
    else:
      raise ValueError(f"fuse_type should be 'sum' or 'cat', but is {fuse_type}")
    self.output = nn.Linear(fuse_dim, out_dim)

  def forward(self, xs):
    if isinstance(xs, torch.Tensor):
      xs = xs.split(self.in_dims, dim=1)
    # assert len(xs) == self.num_views
    encoder_out = []
    decoder_out = []
    for i, x in enumerate(xs):
      out = self.encoders[i](x)
      encoder_out.append(out)
      decoder_out.append(self.decoders[i](out))
    if self.fuse_type == 'sum':
      # 'sum' fusion averages the encoder outputs so the scale does not depend on the number of views
      out = torch.stack(encoder_out, dim=-1).mean(dim=-1)
    else:
      out = torch.cat(encoder_out, dim=-1)
    out = self.output(out)
    return out, torch.cat(decoder_out, dim=-1), torch.cat(encoder_out, dim=-1)
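
# A minimal sketch of the three outputs (shapes assume DenseLinear maps each view to its last hidden dim):
#   >>> model = MultiviewAE([2, 3], [5, 5], 7, fuse_type='cat')
#   >>> scores, recon, codes = model(torch.randn(10, 5))
#   >>> scores.shape, recon.shape, codes.shape
#   (torch.Size([10, 7]), torch.Size([10, 5]), torch.Size([10, 10]))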


def get_interaction_loss(interaction_mat, w, loss_type='graph_laplacian', normalize=True):
  """Calculate a loss on the inconsistency between feature representations w (N*D)
  and a feature interaction network interaction_mat (N*N).
  A trivial solution is for all features (row vectors of w) to have cosine similarity 1 or distance 0.

  Args:
    interaction_mat: non-negative symmetric torch.Tensor with shape (N, N)
    w: feature representation tensor with shape (N, D)
    normalize: if True, call w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
      for loss_type = 'graph_laplacian' or 'dot_product';
        this makes sure w.norm() = 1 and the row vectors of w have the same norm: len(torch.unique(w.norm(dim=1)))==1
      call loss = loss / w.size(0) for loss_type = 'cosine_similarity'.
      By doing this we ensure the number of features is factored out;
      this is useful for combining losses from multiple views.

  See Loss_feature_interaction for more documentation.

  """
  if loss_type == 'cosine_similarity':
    # -(|cos(w,w)| * interaction_mat).sum()
    cos = cosine_similarity(w).abs() # get the absolute value of the cosine similarity
    loss = -(cos * interaction_mat).sum()
    if normalize:
      loss = loss / w.size(0)
  elif loss_type == 'graph_laplacian':
    # trace(w' * L * w)
    if normalize:
      w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
      interaction_mat = interaction_mat / interaction_mat.norm() # ensure interaction_mat is normalized
    diag = torch.diag(interaction_mat.sum(dim=1))
    L_interaction_mat = diag - interaction_mat
    loss = torch.diagonal(torch.mm(torch.mm(w.t(), L_interaction_mat), w)).sum()
  elif loss_type == 'dot_product':
    # pairwise distance mat * interaction mat
    if normalize:
      w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
    d = torch.sum(w*w, dim=1) # if normalize is True, then d is a vector whose elements all equal 1/w.size(0)
    dist = d.unsqueeze(1) + d - 2*torch.mm(w, w.t())
    loss = (dist * interaction_mat).sum()
    # loss = (dist / dist.norm() * interaction_mat).sum() # This is an alternative to 'normalize' loss
  else:
    raise ValueError(f"loss_type can only be 'cosine_similarity', "
                     f"'graph_laplacian' or 'dot_product', but is {loss_type}")
  return loss
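
# Note: for a symmetric, non-negative interaction matrix A with Laplacian L = D - A,
#   trace(w' * L * w) = 0.5 * sum_ij A_ij * ||w_i - w_j||^2,
# so the 'graph_laplacian' loss penalizes distances between strongly interacting features.
# A minimal numeric sketch (normalization disabled to keep the numbers simple):
#   >>> A = torch.tensor([[0., 1.], [1., 0.]])
#   >>> w = torch.eye(2)
#   >>> get_interaction_loss(A, w, loss_type='graph_laplacian', normalize=False)
#   tensor(2.)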


class Loss_feature_interaction(nn.Module):
  r"""A customized loss function for a graph Laplacian constraint on the feature interaction network.
    For a factorization autoencoder model, the decoder weights can be seen as feature representations;
    this loss measures the inconsistency between the learned feature representations and their interaction network.
    A trivial solution is for all features to have cosine similarity 1 or distance 0.

  Args:
    interaction_mat: torch.Tensor of shape (N, N), a non-negative (symmetric) matrix;
      or a list of such matrices, each an interaction mat;
      to control the magnitude of the loss, it is preferred to have interaction_mat.norm() = 1
    loss_type: if loss_type == 'cosine_similarity', calculate -(cos(m, m).abs() * interaction_mat).sum()
               if loss_type == 'graph_laplacian' (faster), calculate trace(m' * L * m)
               if loss_type == 'dot_product', calculate dist(m) * interaction_mat
                 where dist(m) is the pairwise distance matrix of features; the name 'dot_product' is misleading
              If all features have norm 1, the three types are equivalent in a sense;
              cosine_similarity is preferred because the magnitudes of the features are implicitly ignored,
               while the other two are affected by the magnitudes of the features.
    weight_path: default ['decoder', 'weight'], with the goal to get w = model.decoder.weight
    normalize: passed to get_interaction_loss;
      if True, call w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
        for loss_type 'graph_laplacian' or 'dot_product',
          which makes sure each row vector of w has the same norm, and w.norm() = 1;
        call loss = loss / w.size(0) for loss_type = 'cosine_similarity'.
      By doing this we ensure the number of features is factored out;
      this is useful for combining losses from multiple views.

  Inputs:
    model: the AutoEncoder model defined above (or another model), from which w is read via weight_path;
      alternatively, a weight matrix w can be given directly;
    if interaction_mat has shape (N, N), then w has shape (N, D)

  Returns:
    loss: a torch.Tensor on which loss.backward() can be called
  """

  def __init__(self, interaction_mat, loss_type='graph_laplacian', weight_path=['decoder', 'weight'],
    normalize=True):
    super(Loss_feature_interaction, self).__init__()
    self.loss_type = loss_type
    self.weight_path = weight_path
    self.normalize = normalize
    # If interaction_mat is a list, self.sections will be used for splitting the weight matrix
    self.sections = None # when interaction_mat is a single matrix, self.sections is None
    if isinstance(interaction_mat, (list, tuple)):
      if normalize: # ensure each interaction_mat is normalized
        interaction_mat = [m/m.norm() for m in interaction_mat]
      self.sections = [m.shape[0] for m in interaction_mat]
    else:
      if normalize: # ensure interaction_mat is normalized
        interaction_mat = interaction_mat / interaction_mat.norm()
    if self.loss_type == 'graph_laplacian':
      # precalculate self.L_interaction_mat to save some compute in each forward pass
      if self.sections is None:
        diag = torch.diag(interaction_mat.sum(dim=1))
        self.L_interaction_mat = diag - interaction_mat # Graph Laplacian; should I normalize it?
      else:
        self.L_interaction_mat = []
        for mat in interaction_mat:
          diag = torch.diag(mat.sum(dim=1))
          self.L_interaction_mat.append(diag - mat)
    else: # interaction_mat only needs to be stored when loss_type != 'graph_laplacian'
      self.interaction_mat = interaction_mat

  def forward(self, model=None, w=None):
    if w is None:
      w = get_attr(model, self.weight_path)
    if self.sections is None:
      # There is only one interaction matrix
      if self.loss_type == 'graph_laplacian':
        # Use the precalculated L_interaction_mat to save some time
        if self.normalize:
          # interaction_mat has already been normalized during initialization
          w = w / w.norm(p=2, dim=1, keepdim=True) / np.sqrt(w.size(0))
        return torch.diagonal(torch.mm(torch.mm(w.t(), self.L_interaction_mat), w)).sum()
      else:
        return get_interaction_loss(self.interaction_mat, w, loss_type=self.loss_type, normalize=self.normalize)
    else:
      # There is a list of interaction matrices; split w into matching chunks
      if isinstance(w, torch.Tensor):
        w = w.split(self.sections, dim=0)
      if self.loss_type == 'graph_laplacian': # handle 'graph_laplacian' differently to save time during training
        loss = 0
        for w_, L in zip(w, self.L_interaction_mat):
          if self.normalize: # make sure w_.norm() = 1 and each row vector of w_ has the same norm
            w_ = w_ / w_.norm(p=2, dim=1, keepdim=True) / np.sqrt(w_.size(0))
          loss += torch.diagonal(torch.mm(torch.mm(w_.t(), L), w_)).sum()
        return loss
      # for the cases 'cosine_similarity' and 'dot_product'
      return sum([get_interaction_loss(mat, w_, loss_type=self.loss_type, normalize=self.normalize)
                  for mat, w_ in zip(self.interaction_mat, w)])
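
# A usage sketch with the AutoEncoder above (shapes illustrative; get_attr is assumed to walk the attribute path):
#   >>> model = AutoEncoder(in_dim=100, hidden_dims=[64, 32], num_classes=5)
#   >>> A = torch.rand(100, 100); A = (A + A.t()) / 2   # non-negative, symmetric (100 features)
#   >>> criterion = Loss_feature_interaction(A, loss_type='graph_laplacian')
#   >>> loss = criterion(model)   # reads w = model.decoder.weight, shape (100, 32)
#   >>> loss.backward()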


class Loss_view_similarity(nn.Module):
  r"""The input is a multi-view representation of the same set of patients,
      i.e., a set of matrices with shape (num_samples, feature_dim); feature_dim can be different for each view.
    This loss penalizes the inconsistency among different views.
    This is somewhat limited, because different views should have both shared and complementary information:
      this loss only encourages the shared information across views,
      which may or may not be good for certain applications.
    A trivial solution is for the multi-view representations to be all the same; then loss -> -1.
    The two loss_types 'circle' and 'hub' can be quite different and unstable.
      'circle' tries to make all feature representations across views have high cosine similarity,
      while 'hub' only tries to make feature representations within each view have high cosine similarity;
      by multiplying the 'mean-feature' target with the 'hub' loss_type, it might 'magically' capture both
        within-view and cross-view similarity; set as the default choice; but my limited experimental results
        do not validate this; instead, 'circle' and 'hub' are dominant, while explicit_target and cal_target
        do not make a big difference.
    Cosine similarity is used here; TODO: other similarity metrics.

  Args:
    sections: a list of integers (or an int); this is used to split the input matrix into chunks;
      each chunk corresponds to one view representation.
      If the input xs is not a torch.Tensor, this is not used; xs is then assumed to be a list of torch.Tensors.
      If sections is an int, all views are assumed to have the same feature dim; set sections = feature_dim
      (the chunk size), NOT the number of sections!
    loss_type: suppose there are three views x1, x2, x3; let s_ij = cos(x_i,x_j), s_i = cos(x_i,x_i)
      if loss_type=='circle', similarity = s12*s23*target if fusion_type=='multiply'; s12+s23 if fusion_type=='sum'
        This is fastest but requires x1, x2, x3 to have the same shape
      if loss_type=='hub', similarity=s1*s2*s3*target if fusion_type=='multiply';
        similarity=|s1|+|s2|+|s3|+|target| if fusion_type=='sum'
        Implicitly, target=1 (fusion_type=='multiply') or 0 (fusion_type=='sum') if explicit_target is False
        if graph_laplacian is False:
          loss = - similarity.abs().mean()
        else:
          s = similarity.abs(); L_s = torch.diag(sum(s, axis=1)) - s # graph Laplacian
          loss = sum_i(x_i * L_s * x_i^T)
    explicit_target: if False, target=1 (fusion_type=='multiply') or 0 (fusion_type=='sum') implicitly;
      if True, use the given target or calculate it from xs
      # to do: handle the case when we only use the explicitly given target
    cal_target: if 'mean-similarity', target = (cos(x1,x1) + cos(x2,x2) + cos(x3,x3))/3
                if 'mean-feature', x = (x1+x2+x3)/3; target = cos(x,x); this requires x1,x2,x3 to have the same shape
    target: default None; only used when explicit_target is True.
      This saves computation if target is provided in advance or passed as input.
    fusion_type: if 'multiply', similarity=product(similarities); if 'sum', similarity=sum(|similarities|);
      works together with loss_type
    graph_laplacian: if graph_laplacian is False:
          loss = - similarity.abs().mean()
        else:
          s = similarity.abs(); L_s = torch.diag(sum(s, axis=1)) - s # graph Laplacian
          loss = sum_i(x_i * L_s * x_i^T)

  Inputs:
    xs: a set of torch.Tensor matrices of shape (num_samples, feature_dim),
      or a single matrix with self.sections being specified
    target: the target cosine similarity matrix; default None;
      if not given, first check if self.target is given;
        if self.target is None, then calculate it according to cal_target;
      only used when self.explicit_target is True

  Output:
    loss = -similarity.abs().mean() if graph_laplacian is False # Is this the right way to do it?
         = sum_i(x_i * L_s * x_i^T) if graph_laplacian is True # calls get_interaction_loss()

  """
  def __init__(self, sections=None, loss_type='hub', explicit_target=False,
    cal_target='mean-feature', target=None, fusion_type='multiply', graph_laplacian=False):
    super(Loss_view_similarity, self).__init__()
    self.sections = sections
    if self.sections is not None:
      if not isinstance(self.sections, int):
        assert len(self.sections) >= 2
    self.loss_type = loss_type
    assert self.loss_type in ['circle', 'hub']
    self.explicit_target = explicit_target
    self.cal_target = cal_target
    self.target = target
    self.fusion_type = fusion_type
    self.graph_laplacian = graph_laplacian
    # I easily got nan losses whenever graph_laplacian was True, especially in the cases below; I do not know why.
    # Probably the similarity needs to be normalized during every forward pass?
    assert not (fusion_type=='multiply' and graph_laplacian) and not (loss_type=='circle' and graph_laplacian)

  def forward(self, xs, target=None):
    if isinstance(xs, torch.Tensor):
      # make sure xs is a list of tensors corresponding to multiple views;
      # this requires self.sections to be valid
      xs = xs.split(self.sections, dim=1)
    # assert len(xs) >= 2 # comment this out to save time over many forward passes
    similarity = 1
    if self.loss_type == 'circle':
      # assert xs[i-1].shape == xs[i].shape
      # this saves computation
      similarity_mats = [cosine_similarity(xs[i-1], xs[i]) for i in range(1, len(xs))]
      similarity_mats = [(m+m.t())/2 for m in similarity_mats] # make it symmetric
    elif self.loss_type == 'hub':
      similarity_mats = [cosine_similarity(x) for x in xs]
    if self.fusion_type=='multiply':
      for m in similarity_mats:
        similarity = similarity * m # element-wise multiplication keeps the largest value at 1
    elif self.fusion_type=='sum':
      similarity = sum(similarity_mats) / len(similarity_mats) # take the mean so the largest value stays 1

    if self.explicit_target:
      if target is None:
        if self.target is None:
          if self.cal_target == 'mean-similarity':
            target = torch.stack(similarity_mats, dim=0).mean(0)
          elif self.cal_target == 'mean-feature':
            x = torch.stack(xs, -1).mean(-1) # the view matrices must have the same dimension
            target = cosine_similarity(x)
          else:
            raise ValueError(f'cal_target should be mean-similarity or mean-feature, but is {self.cal_target}')
        else:
          target = self.target
      if self.fusion_type=='multiply':
        similarity = similarity * target
      elif self.fusion_type=='sum':
        similarity = (len(similarity_mats)*similarity + target) / (len(similarity_mats) + 1) # moving average
    similarity = similarity.abs() # ensure similarity is non-negative
    if self.graph_laplacian:
      # Easily gets nan loss when this is True; I do not know why
      return sum([get_interaction_loss(similarity, w, loss_type='graph_laplacian', normalize=True) for w in xs]) / len(xs)
    else:
      return -similarity.mean() # ensure the loss is within the range [-1, 0]
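
# A usage sketch with two equally sized views (assumes cosine_similarity(x) returns the cos(x, x) matrix):
#   >>> criterion = Loss_view_similarity(sections=[16, 16], loss_type='hub')
#   >>> xs = torch.randn(10, 32, requires_grad=True)  # two views of 16 features each, concatenated along dim 1
#   >>> loss = criterion(xs)                          # scalar in [-1, 0]; -1 means the views fully agree
#   >>> loss.backward()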