[e66fb7]: / src / mil_models / model_PANTHER.py

Download this file

56 lines (45 with data), 1.9 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Model initiation for PANTHER
from torch import nn
import numpy as np
from .components import predict_surv, predict_clf, predict_emb
from .PANTHER.layers import PANTHERBase
from utils.proto_utils import check_prototypes
class PANTHER(nn.Module):
    """
    nn.Module wrapper around the PANTHER layer (``PANTHERBase``).

    The wrapped layer builds the unsupervised slide representation; this
    class exposes it via ``representation`` / ``forward`` and routes
    dataset-level inference through ``predict`` according to ``mode``.
    """

    def __init__(self, config, mode):
        super().__init__()
        # Keep the raw config plus the handful of fields read from it.
        self.config = config
        self.emb_dim = config.in_dim
        self.heads = config.heads
        self.outsize = config.out_size
        self.load_proto = config.load_proto
        self.mode = mode

        # The prototype check (utils.proto_utils.check_prototypes) must run
        # before the PANTHER layer below is constructed.
        check_prototypes(config.out_size, self.emb_dim, self.load_proto,
                         config.proto_path)

        # This module contains the EM step; `L` is fed from config.em_iter.
        self.panther = PANTHERBase(
            self.emb_dim,
            p=config.out_size,
            L=config.em_iter,
            tau=config.tau,
            out=config.out_type,
            ot_eps=config.ot_eps,
            load_proto=config.load_proto,
            proto_path=config.proto_path,
            fix_proto=config.fix_proto,
        )

    def representation(self, x):
        """
        Construct the unsupervised slide representation for ``x``.

        Returns a dict whose ``'repr'`` entry is the first value returned by
        the PANTHER layer and whose ``'qq'`` entry is the second.
        """
        slide_repr, qq = self.panther(x)
        return {'repr': slide_repr, 'qq': qq}

    def forward(self, x):
        """Return only the ``'repr'`` entry produced by ``representation``."""
        return self.representation(x)['repr']

    def predict(self, data_loader, use_cuda=True):
        """
        Run inference over ``data_loader.dataset`` with the helper for ``self.mode``.

        Returns:
            ``(output, y)``. ``y`` is always ``None`` in ``'emb'`` mode.

        Raises:
            NotImplementedError: if ``self.mode`` is not ``'classification'``,
            ``'survival'`` or ``'emb'``. This is raised before
            ``data_loader`` is inspected.
        """
        mode = self.mode
        if mode == 'classification':
            predictor = predict_clf
        elif mode == 'survival':
            predictor = predict_surv
        elif mode == 'emb':
            # Embedding mode has no targets, so pair the output with None.
            return predict_emb(self, data_loader.dataset, use_cuda=use_cuda), None
        else:
            raise NotImplementedError(f"Not implemented for {self.mode}!")

        output, y = predictor(self, data_loader.dataset, use_cuda=use_cuda)
        return output, y