# src/mil_models/model_linear.py

import torch
import torch.nn as nn

from .components import create_mlp_with_dropout, process_surv, process_clf


class LinearEmb(nn.Module):
    """
    Linear fully-connected layer from slide representation to output.
    """

    def __init__(self, config, mode):
        super().__init__()
        self.config = config
        self.classifier = nn.Linear(config.in_dim, config.n_classes, bias=False)
        self.n_classes = config.n_classes
        self.mode = mode

    def forward_no_loss(self, h, attn_mask=None):
        # h: (batch, in_dim) slide-level representation.
        logits = self.classifier(h)
        return {'logits': logits}

    def forward(self, h, model_kwargs={}):
        if self.mode == 'classification':
            attn_mask = model_kwargs['attn_mask']
            label = model_kwargs['label']
            loss_fn = model_kwargs['loss_fn']
            out = self.forward_no_loss(h, attn_mask=attn_mask)
            logits = out['logits']
            results_dict, log_dict = process_clf(logits, label, loss_fn)
        elif self.mode == 'survival':
            attn_mask = model_kwargs['attn_mask']
            label = model_kwargs['label']
            censorship = model_kwargs['censorship']
            loss_fn = model_kwargs['loss_fn']
            out = self.forward_no_loss(h, attn_mask=attn_mask)
            logits = out['logits']
            results_dict, log_dict = process_surv(logits, label, censorship, loss_fn)
        else:
            raise NotImplementedError(f"Unsupported mode: {self.mode!r}")
        return results_dict, log_dict

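
# Usage sketch for LinearEmb, kept as a comment so the module stays
# import-safe. `SimpleNamespace` merely stands in for the project's real
# config object, and the 768/2 dimensions are illustrative only:
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(in_dim=768, n_classes=2)
#     model = LinearEmb(cfg, mode='classification')
#     logits = model.forward_no_loss(torch.randn(4, 768))['logits']
#     logits.shape  # torch.Size([4, 2])
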
#
# MLP per prototype
#
class IndivMLPEmb(nn.Module):
    """
    Composed of three MLPs applied in sequence, each of which can be
    enabled/disabled and configured independently:
    - Shared: shared MLP across prototypes for feature dimension reduction
    - Indiv: individual MLP per prototype
    - Post: shared MLP across prototypes for final feature dimension reduction
    """

    def __init__(self, config, mode):
        super().__init__()
        self.config = config
        self.n_classes = config.n_classes
        self.p = config.p  # number of prototypes
        self.mode = mode
        mlp_func = create_mlp_with_dropout
        if config.shared_mlp:
            # Shared MLP: reduces the feature dimension of every prototype.
            self.shared_mlp = mlp_func(in_dim=config.in_dim,
                                       hid_dims=[config.shared_embed_dim] *
                                       (config.n_fc_layers - 1),
                                       dropout=config.shared_dropout,
                                       out_dim=config.shared_embed_dim,
                                       end_with_fc=False)
            next_in_dim = config.shared_embed_dim
        else:
            self.shared_mlp = nn.Identity()
            next_in_dim = config.in_dim

        if config.indiv_mlps:
            # Individual MLPs: one per prototype, applied to that prototype's slice.
            self.indiv_mlps = nn.ModuleList([mlp_func(in_dim=next_in_dim,
                                                      hid_dims=[config.indiv_embed_dim] *
                                                      (config.n_fc_layers - 1),
                                                      dropout=config.indiv_dropout,
                                                      out_dim=config.indiv_embed_dim,
                                                      end_with_fc=False) for _ in range(config.p)])
            next_in_dim = config.p * config.indiv_embed_dim
        else:
            self.indiv_mlps = nn.ModuleList([nn.Identity() for _ in range(config.p)])
            next_in_dim = config.p * next_in_dim

        if config.postcat_mlp:
            # Post MLP: applied to the concatenation of all prototype embeddings.
            self.postcat_mlp = mlp_func(in_dim=next_in_dim,
                                        hid_dims=[config.postcat_embed_dim] *
                                        (config.n_fc_layers - 1),
                                        dropout=config.postcat_dropout,
                                        out_dim=config.postcat_embed_dim,
                                        end_with_fc=False)
            next_in_dim = config.postcat_embed_dim
        else:
            self.postcat_mlp = nn.Identity()

        self.classifier = nn.Linear(next_in_dim, config.n_classes, bias=False)

    def forward_no_loss(self, h, attn_mask=None):
        # h: (batch, n_proto, in_dim) prototype-level slide representation.
        h = self.shared_mlp(h)
        # Apply each prototype's own MLP to its slice, then re-stack.
        h = torch.stack([self.indiv_mlps[idx](h[:, idx, :]) for idx in range(self.p)], dim=1)
        h = h.reshape(h.shape[0], -1)  # (batch, n_proto * indiv_embed_dim)
        h = self.postcat_mlp(h)
        logits = self.classifier(h)
        return {'logits': logits}

    def forward(self, h, model_kwargs={}):
        if self.mode == 'classification':
            attn_mask = model_kwargs['attn_mask']
            label = model_kwargs['label']
            loss_fn = model_kwargs['loss_fn']
            out = self.forward_no_loss(h, attn_mask=attn_mask)
            logits = out['logits']
            results_dict, log_dict = process_clf(logits, label, loss_fn)
        elif self.mode == 'survival':
            attn_mask = model_kwargs['attn_mask']
            label = model_kwargs['label']
            censorship = model_kwargs['censorship']
            loss_fn = model_kwargs['loss_fn']
            out = self.forward_no_loss(h, attn_mask=attn_mask)
            logits = out['logits']
            results_dict, log_dict = process_surv(logits, label, censorship, loss_fn)
        else:
            raise NotImplementedError(f"Unsupported mode: {self.mode!r}")
        return results_dict, log_dict
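

# Minimal smoke test for IndivMLPEmb, assuming a config exposing the fields
# read in __init__. All dimensions below are illustrative, and SimpleNamespace
# merely stands in for the project's real config class; it only exercises
# calls that already appear in this file.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(in_dim=768, n_classes=2, p=16, n_fc_layers=2,
                          shared_mlp=True, shared_embed_dim=256, shared_dropout=0.1,
                          indiv_mlps=True, indiv_embed_dim=64, indiv_dropout=0.1,
                          postcat_mlp=True, postcat_embed_dim=256, postcat_dropout=0.1)
    model = IndivMLPEmb(cfg, mode='classification')
    h = torch.randn(4, cfg.p, cfg.in_dim)  # (batch, n_proto, in_dim)
    logits = model.forward_no_loss(h)['logits']
    print(logits.shape)  # expected: torch.Size([4, 2])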