Switch to unified view

a b/survival4D/nn/torch/models.py
1
from typing import Tuple
2
import numpy as np
3
import torch
4
from torch.utils.data import TensorDataset, DataLoader
5
6
7
class TorchModel(torch.nn.Module):
8
    def predict(self, x: np.ndarray, batch_size: int = 1) -> Tuple[np.ndarray, np.ndarray]:
9
        x = torch.from_numpy(x).cuda().float()
10
        dataset = TensorDataset(x)
11
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
12
        self.eval()
13
        decodeds = []
14
        risks = []
15
        for x in dataloader:
16
            decoded, risk_pred = self(x[0])
17
            decodeds.append(decoded)
18
            risks.append(risk_pred)
19
        decodeds = torch.cat(decodeds, dim=0)
20
        decodeds = decodeds.cpu().detach().numpy()
21
        risks = torch.cat(risks, dim=0)
22
        risks = risks.cpu().detach().numpy()
23
        return decodeds, risks
24
25
26
class BaselineAutoencoder(TorchModel):
27
    def __init__(self, input_shape: int, dropout: float, num_ae_units1: int, num_ae_units2: int):
28
        super().__init__()
29
        num_ae_units1 = round(num_ae_units1)
30
        num_ae_units2 = round(num_ae_units2)
31
        self.encoder = torch.nn.Sequential(
32
            torch.nn.Dropout(dropout),
33
            torch.nn.Linear(input_shape, num_ae_units1),
34
            torch.nn.ReLU(),
35
            torch.nn.Linear(num_ae_units1, num_ae_units2),
36
            torch.nn.ReLU(),
37
        )
38
39
        self.decoder = torch.nn.Sequential(
40
            torch.nn.Linear(num_ae_units2, num_ae_units1),
41
            torch.nn.ReLU(),
42
            torch.nn.Linear(num_ae_units1, input_shape)
43
        )
44
        self.risk_regressor = torch.nn.Linear(num_ae_units2, 1)
45
46
    def forward(self, x):
47
        encoded = self.encoder(x)
48
        risk_pred = self.risk_regressor(encoded)
49
        decoded = self.decoder(encoded)
50
        return decoded, risk_pred
51
52
53
class BaselineBNAutoencoder(TorchModel):
54
    def __init__(self, input_shape: int, dropout: float, num_ae_units1: int, num_ae_units2: int):
55
        super().__init__()
56
        num_ae_units1 = round(num_ae_units1)
57
        num_ae_units2 = round(num_ae_units2)
58
        self.encoder = torch.nn.Sequential(
59
            torch.nn.Dropout(dropout),
60
            torch.nn.Linear(input_shape, num_ae_units1),
61
            torch.nn.BatchNorm1d(num_ae_units1),
62
            torch.nn.ReLU(),
63
            torch.nn.Linear(num_ae_units1, num_ae_units2),
64
            torch.nn.BatchNorm1d(num_ae_units2),
65
            torch.nn.ReLU(),
66
        )
67
68
        self.decoder = torch.nn.Sequential(
69
            torch.nn.Linear(num_ae_units2, num_ae_units1),
70
            torch.nn.BatchNorm1d(num_ae_units1),
71
            torch.nn.ReLU(),
72
            torch.nn.Linear(num_ae_units1, input_shape)
73
        )
74
        self.risk_regressor = torch.nn.Sequential(
75
            torch.nn.Linear(num_ae_units2, num_ae_units2//2),
76
            torch.nn.BatchNorm1d(num_ae_units2//2),
77
            torch.nn.ReLU(),
78
            torch.nn.Linear(num_ae_units2//2, 1)
79
        )
80
81
    def forward(self, x):
82
        encoded = self.encoder(x)
83
        risk_pred = self.risk_regressor(encoded)
84
        decoded = self.decoder(encoded)
85
        return decoded, risk_pred
86
87
88
def model_factory(model_name: str, **kwargs) -> TorchModel:
89
    # Before defining network architecture, clear current computation graph (if one exists)
90
    torch.cuda.empty_cache()
91
    if model_name == "baseline_autoencoder":
92
        model = BaselineAutoencoder(**kwargs)
93
    elif model_name == "baseline_bn_autoencoder":
94
        model = BaselineBNAutoencoder(**kwargs)
95
    else:
96
        raise ValueError("Model name {} has not been implemented.".format(model_name))
97
    return model