# src/mil_models/model_abmil.py (commit 286bfb)

import torch
import torch.nn as nn
import torch.nn.functional as F

from .components import Attn_Net, Attn_Net_Gated, create_mlp, process_surv, process_clf
from .model_configs import ABMILConfig


class ABMIL(nn.Module):
    """Attention-based multiple instance learning (ABMIL) over bags of instance features."""

    def __init__(self, config: ABMILConfig, mode):
        super().__init__()
        self.config = config

        # Per-instance embedding MLP: in_dim -> embed_dim
        self.mlp = create_mlp(in_dim=config.in_dim,
                              hid_dims=[config.embed_dim] * (config.n_fc_layers - 1),
                              dropout=config.dropout,
                              out_dim=config.embed_dim,
                              end_with_fc=False)

        # Attention network scoring each instance (gated or plain variant)
        if config.gate:
            self.attention_net = Attn_Net_Gated(L=config.embed_dim,
                                                D=config.attn_dim,
                                                dropout=config.dropout,
                                                n_classes=1)
        else:
            self.attention_net = Attn_Net(L=config.embed_dim,
                                          D=config.attn_dim,
                                          dropout=config.dropout,
                                          n_classes=1)

        self.classifier = nn.Linear(config.embed_dim, config.n_classes)
        self.n_classes = config.n_classes
        self.mode = mode

    def forward_attention(self, h, attn_only=False):
        # B: batch size
        # N: number of instances per WSI
        # L: input dimension
        # K: number of attention heads (K = 1 for ABMIL)
        # h is B x N x L
        h = self.mlp(h)                      # B x N x D
        A = self.attention_net(h)            # B x N x K
        A = torch.transpose(A, -2, -1)       # B x K x N
        if attn_only:
            return A
        return h, A

    def forward_no_loss(self, h, attn_mask=None):
        h, A = self.forward_attention(h)
        A_raw = A                            # unnormalized attention scores, B x K x N
        if attn_mask is not None:
            # Mask out padded instances before the softmax
            A = A + (1 - attn_mask).unsqueeze(dim=1) * torch.finfo(A.dtype).min
        A = F.softmax(A, dim=-1)             # softmax over N
        M = torch.bmm(A, h).squeeze(dim=1)   # B x K x D --> B x D
        logits = self.classifier(M)

        out = {'logits': logits, 'attn': A, 'feats': h, 'feats_agg': M}
        return out

    def forward(self, h, model_kwargs=None):
        model_kwargs = model_kwargs if model_kwargs is not None else {}
        if self.mode == 'classification':
            attn_mask = model_kwargs['attn_mask']
            label = model_kwargs['label']
            loss_fn = model_kwargs['loss_fn']
            out = self.forward_no_loss(h, attn_mask=attn_mask)
            logits = out['logits']
            results_dict, log_dict = process_clf(logits, label, loss_fn)
        elif self.mode == 'survival':
            attn_mask = model_kwargs['attn_mask']
            label = model_kwargs['label']
            censorship = model_kwargs['censorship']
            loss_fn = model_kwargs['loss_fn']
            out = self.forward_no_loss(h, attn_mask=attn_mask)
            logits = out['logits']
            results_dict, log_dict = process_surv(logits, label, censorship, loss_fn)
        else:
            raise NotImplementedError(f"Mode '{self.mode}' is not implemented.")
        return results_dict, log_dict

# class ABMILSurv(ABMIL):
#     def __init__(self, config: ABMILConfig):
#         super().__init__(config)
#
#     def forward(self, h, attn_mask=None, label=None, censorship=None, loss_fn=None):
#         out = self.forward_no_loss(h, attn_mask=attn_mask)
#         logits = out['logits']
#         results_dict, log_dict = process_surv(logits, label, censorship, loss_fn)
#         return results_dict, log_dict
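
# Minimal usage sketch (illustrative, not part of the original module): builds an
# ABMIL model and runs a forward pass on random instance features. ABMILConfig's
# constructor is not shown in this file, so a SimpleNamespace carrying the attributes
# read above (in_dim, embed_dim, attn_dim, n_fc_layers, dropout, gate, n_classes) is
# used as an assumed stand-in; the chosen dimensions are arbitrary.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(in_dim=1024, embed_dim=512, attn_dim=256,
                          n_fc_layers=2, dropout=0.25, gate=True, n_classes=2)
    model = ABMIL(cfg, mode='classification')

    h = torch.randn(1, 100, cfg.in_dim)   # one bag of N=100 instance features
    out = model.forward_no_loss(h)        # no attn_mask: attend over all instances
    print(out['logits'].shape)            # torch.Size([1, 2])
    print(out['attn'].shape)              # torch.Size([1, 1, 100])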