
Classifier/Classes/Model.py
import numpy as np
from sklearn import metrics
from sklearn.metrics import roc_auc_score, accuracy_score
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from Classifier.Classes.Helpers import Helpers

helper = Helpers("Model", False)

class LuekemiaNet(nn.Module):
    """ Model Class

    Model functions for the Acute Lymphoblastic Leukemia PyTorch CNN 2020.
    """
    def __init__(self):
        super(LuekemiaNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(2))

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(2))

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(2))

        self.fc = nn.Sequential(
            nn.Linear(128*2*2, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.Dropout(0.2),
            nn.Linear(512, 2))

    def forward(self, x):
        """Method for Forward Prop"""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        #######################
        # For model debugging #
        #######################
        # print(out.shape)

        out = out.view(x.shape[0], -1)
        out = self.fc(out)
        return out
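
# A minimal shape sanity check (sketch only, not part of the training pipeline).
# Every conv block above ends in nn.AdaptiveAvgPool2d(2), so the flattened
# feature size entering self.fc is always 128 * 2 * 2 = 512, regardless of the
# input resolution; the 224x224 size below is an assumed example, not a value
# taken from the project configuration.
#
#   model = LuekemiaNet()
#   dummy = torch.randn(4, 3, 224, 224)
#   print(model(dummy).shape)  # expected: torch.Size([4, 2])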


###################
# Train model     #
###################
def learn(model, loss_function, optimizer, epochs, train_loader,
          valid_loader, scheduler, train_on_gpu=True):
    """Train the model, validate it after each epoch, log the per-epoch
    metrics, save the trained weights and plot the loss and AUC curves."""

    trainset_auc = list()
    train_losses = list()
    valid_losses = list()
    val_auc = list()
    valid_auc_epoch = list()
    train_auc_epoch = list()

    for epoch in range(epochs):
        train_loss = 0.0
        valid_loss = 0.0
        model.train()
        for i, data in enumerate(train_loader):
            inputs, labels = data
            if train_on_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()
            # set the gradients of all optimized variables to zero
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing the inputs to the model
            outputs = model(inputs)
            # compute the batch loss (binary cross entropy)
            loss = loss_function(outputs[:, 0], labels.float())
            # update the running training loss
            train_loss += loss.item() * inputs.size(0)
            # backward pass: compute the gradient of the loss with respect to the model parameters
            loss.backward()
            # perform a single optimization step and update the learning rate
            optimizer.step()
            scheduler.step()
            y_actual = labels.data.cpu().numpy()
            prediction = outputs[:, 0].detach().cpu().numpy()
            trainset_auc.append(roc_auc_score(y_true=y_actual, y_score=prediction))

        model.eval()
        for _, data in enumerate(valid_loader):
            inputs, labels = data
            # move the batch to the GPU if one is available
            if train_on_gpu:
                inputs, labels = inputs.cuda(), labels.cuda()
            # forward pass: compute predicted outputs by passing the inputs to the model
            outputs = model(inputs)
            # compute the batch loss (binary cross entropy)
            loss = loss_function(outputs[:, 0], labels.float())
            # update the running validation loss
            valid_loss += loss.item() * inputs.size(0)
            y_actual = labels.data.cpu().numpy()
            predictionx = outputs[:, 0].detach().cpu().numpy()
            val_auc.append(roc_auc_score(y_actual, predictionx))

        # calculate the average losses and AUC for the epoch
        train_loss = train_loss / len(train_loader.sampler)
        valid_loss = valid_loss / len(valid_loader.sampler)
        valid_auc = np.mean(val_auc)
        train_auc = np.mean(trainset_auc)
        train_auc_epoch.append(np.mean(trainset_auc))
        valid_auc_epoch.append(np.mean(val_auc))
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        helper.logger.info(
            'Epoch: {} | Training Loss: {:.6f} | Training AUC: {:.6f} | Validation Loss: {:.6f} | Validation AUC: {:.4f}'.format(
                epoch, train_loss, train_auc, valid_loss, valid_auc))

    torch.save(model.state_dict(), helper.config["classifier"]["model_params"]["weights"])

    plt.plot(train_losses, label='Training loss')
    plt.plot(valid_losses, label='Validation loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(frameon=False)
    plt.savefig(helper.config["classifier"]["model_params"]["plot_loss"])
    plt.show()

    plt.plot(valid_auc_epoch, label='Validation AUC/Epochs')
    plt.plot(train_auc_epoch, label='Training AUC/Epochs')
    plt.xlabel("Epochs")
    plt.ylabel("Area Under the Curve")
    plt.legend(frameon=False)
    plt.savefig(helper.config["classifier"]["model_params"]["plot_auc"])
    plt.show()
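

##########################
# Example usage (sketch) #
##########################
# Hypothetical wiring of the pieces above, for illustration only: the loss,
# optimizer, scheduler, dummy data and epoch count below are assumptions made
# for this sketch, not values taken from the project's configuration files.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    use_gpu = torch.cuda.is_available()
    model = LuekemiaNet().cuda() if use_gpu else LuekemiaNet()
    # outputs[:, 0] are raw logits, so a logit-based binary loss is assumed here
    loss_function = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # random tensors stand in for the real datasets; labels alternate so every
    # batch contains both classes and roc_auc_score stays well-defined
    train_data = TensorDataset(torch.randn(32, 3, 224, 224), torch.tensor([0, 1] * 16))
    valid_data = TensorDataset(torch.randn(16, 3, 224, 224), torch.tensor([0, 1] * 8))
    train_loader = DataLoader(train_data, batch_size=8, shuffle=False)
    valid_loader = DataLoader(valid_data, batch_size=8, shuffle=False)

    learn(model, loss_function, optimizer, 2, train_loader,
          valid_loader, scheduler, train_on_gpu=use_gpu)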