Diff of /Classifier/AllPytorch.py [000000] .. [f757a9]


b/Classifier/AllPytorch.py
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from torch.optim.lr_scheduler import OneCycleLR
import torch
import torch.nn as nn
import os
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.optim import Adam
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import PIL.Image as Image

import json
from matplotlib.colors import LinearSegmentedColormap

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import NoiseTunnel
from captum.attr import Saliency
from captum.attr import visualization as viz


SEED = 323

def seed_everything(seed=SEED):
    """Seed every RNG used here so runs are reproducible."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# weight decay
decay = 5e-4
# training batch size
batch_size = 10
# use CUDA if it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeukemiaDataset(Dataset):
    """
    Acute Lymphoblastic Leukemia dataset reader.

    Args:
        df_data: DataFrame loaded from the training CSV (image id, label)
        data_dir: path to the Acute Lymphoblastic Leukemia image folder
        transform: torchvision transforms used for data augmentation
    """

    def __init__(self, df_data, data_dir='./', transform=None):
        super().__init__()
        self.df = df_data.values
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img_name, label = self.df[index]
        img_path = os.path.join(self.data_dir, img_name + '.jpg')
        image = Image.open(img_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label


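# Illustrative usage sketch of LeukemiaDataset (not part of the original script; 'train.csv'
# and 'all_train/' are placeholder paths, and the CSV is assumed to have exactly the two
# columns __getitem__ unpacks -- image id and label):
#   df = pd.read_csv('train.csv')
#   ds = LeukemiaDataset(df_data=df, data_dir='all_train/', transform=None)
#   image, label = ds[0]   # PIL image (a tensor once a transform is supplied) and its label
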
def augmentation():
    """Acute Lymphoblastic Leukemia data augmentation.

    Normalisation uses the standard ImageNet channel statistics.
    """
    mean, std_dev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    training_transforms = transforms.Compose([transforms.Resize((100, 100)),
                                              transforms.RandomRotation(30),
                                              transforms.RandomResizedCrop(100),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomGrayscale(p=0.1),
                                              transforms.ToTensor(),
                                              transforms.Normalize(mean=mean, std=std_dev)])

    validation_transforms = transforms.Compose([
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std_dev)])

    return training_transforms, validation_transforms


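# Illustrative only: both pipelines from augmentation() are applied per image by
# LeukemiaDataset, e.g. training_transforms(Image.open('some_cell.jpg').convert('RGB'))
# yields a normalised 3 x 100 x 100 tensor ('some_cell.jpg' is a hypothetical path).
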
def predict_probability(model, transforms, image_path, use_gpu=True):
    """Return the per-class sigmoid scores for a single image."""
    img = Image.open(image_path)
    img = img.convert('RGB')
    image = transforms(img).unsqueeze(0)
    if use_gpu:
        image = image.cuda()

    predictions = model(image)
    predictions = torch.sigmoid(predictions)
    predictions = predictions.detach().cpu().numpy().flatten()
    return predictions


def show_prediction_confidence(prediction, class_names):
    """Plot the predicted score for each class as a horizontal bar chart."""
    pred_df = pd.DataFrame({
        'class_names': class_names,
        'values': prediction
    })
    sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
    plt.xlim([0, 1])


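# Illustrative only ('trained_model' and 'some_cell.jpg' are placeholders, class names
# mirror the ones used at the bottom of this script):
#   probs = predict_probability(trained_model, validation_transforms, 'some_cell.jpg', use_gpu=False)
#   show_prediction_confidence(probs, class_names=["zero", "one"])
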
def get_predictions(model, data_loader, use_gpu=True):
    """Run the model over a data loader and collect predictions and true labels."""
    model = model.to(device)
    model = model.eval()
    y_predictions = list()
    y_true = list()
    with torch.no_grad():
        for i, dataset in enumerate(data_loader):
            inputs, labels = dataset
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            y_predictions.extend(preds)
            y_true.extend(labels)

    predictions = torch.as_tensor(y_predictions).cpu()
    y_true = torch.as_tensor(y_true).cpu()
    return predictions, y_true


def confusion_matrix2(confusion_matrix, class_names, save_path):
    """Plot a row-normalised confusion matrix with counts and percentages, then save it."""
    cm = confusion_matrix.copy()

    cell_counts = cm.flatten()

    cm_row_norm = cm / cm.sum(axis=1)[:, np.newaxis]

    row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()]

    cell_labels = [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)]
    cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1])

    df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)

    hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True diagnostic')
    plt.xlabel('Predicted diagnostic')
    plt.savefig(save_path)
    plt.show()


# Colormap settings used by the Captum visualisations below
color = [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')]
name = 'custom blue'
N = 256

def linear_seg_color_map(name, color, N, gamma=1.0):
    """
    Build a colormap from a lookup table.
    :param name: name of the colormap
    :param color: list of (position, color code) stops
    :param N: number of RGB quantization levels
    :param gamma: gamma value, default 1.0
    :return: a matplotlib LinearSegmentedColormap
    """
    default_cmap = LinearSegmentedColormap.from_list(name, color, N, gamma)
    return default_cmap


def predict(model, transforms, image_path, use_cpu):
    """
    Run a single-image forward pass and return the top prediction.

    :param model: trained model
    :param transforms: data transform applied to the input image
    :param image_path: inference input image path
    :param use_cpu: run on the CPU when True, otherwise move the model and image to the GPU
    :return: (batched input tensor, transformed image, prediction score, predicted label index)
    """
    model.cpu()
    model = model.eval()
    img = Image.open(image_path)
    img = img.convert('RGB')
    transformed_img = transforms(img)
    image = transformed_img.unsqueeze(0)
    if use_cpu:
        image = image.cpu()
    else:
        model.cuda()
        image = image.cuda()
    output = model(image)

    output = torch.softmax(output, dim=1)
    prediction_score, pred_label_idx = torch.topk(output, 1)
    return image, transformed_img, prediction_score, pred_label_idx


def interpret_model(model, transforms, image_path, label_path, use_cpu=True, interpret_type=""):
    """
    :param model: our model
    :param transforms: data transformation
    :param image_path: inference input image path
    :param label_path: path to the JSON class-index file
    :param use_cpu: run on the CPU when True
    :param interpret_type: mode for model interpretability: "integrated gradients" for
    Integrated Gradients, "integrated gradient noise" for Integrated Gradients with a
    Noise Tunnel, "gradient shap" for Gradient Shap and "saliency" for Saliency
    :return:
    """
    with open(label_path) as json_data:
        idx_to_labels = json.load(json_data)

    # Check if mode is Integrated Gradients
    if interpret_type == "integrated gradients":

        print('Performing Integrated Gradients Model Interpretation', interpret_type)
        image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu)
        pred_label_idx.squeeze_()
        predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
        print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')

        integrated_gradients = IntegratedGradients(model)
        attributions_ig = integrated_gradients.attribute(image, target=pred_label_idx, n_steps=20)

        _ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                              np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                              ["original_image", "heat_map"],
                                              ["all", "absolute_value"],
                                              cmap=linear_seg_color_map(name, color, N),
                                              show_colorbar=True)

    # Check if mode is Integrated Gradients with Noise Tunnel
    elif interpret_type == "integrated gradient noise":
        print('Performing Integrated Gradients Noise Tunnel Model Interpretation', interpret_type)
        image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu)
        pred_label_idx.squeeze_()
        predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
        print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')

        integrated_gradients = IntegratedGradients(model)
        noise_tunnel = NoiseTunnel(integrated_gradients)

        attributions_ig_nt = noise_tunnel.attribute(image, n_samples=10, nt_type='smoothgrad_sq', target=pred_label_idx)
        _ = viz.visualize_image_attr_multiple(
            np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
            ["original_image", "heat_map"],
            ["all", "positive"],
            cmap=linear_seg_color_map(name, color, N),
            show_colorbar=True)

    # Check if mode is Gradient Shap
    elif interpret_type == "gradient shap":

        print('Performing Gradient Shap Model Interpretation', interpret_type)
        image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu)
        pred_label_idx.squeeze_()
        predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
        print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')
        gradient_shap = GradientShap(model)

        # Defining baseline distribution of images
        rand_img_dist = torch.cat([image * 0, image * 1])

        attributions_gs = gradient_shap.attribute(image,
                                                  n_samples=50,
                                                  stdevs=0.0001,
                                                  baselines=rand_img_dist,
                                                  target=pred_label_idx)
        _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                              np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                              ["original_image", "heat_map"],
                                              ["all", "absolute_value"],
                                              cmap=linear_seg_color_map(name, color, N),
                                              show_colorbar=True)

    # Check if mode is Saliency
    elif interpret_type == "saliency":

        print('Performing Saliency Model Interpretation', interpret_type)
        image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu)
        pred_label_idx.squeeze_()
        predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
        print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')
        saliency = Saliency(model)
        attributions_gs = saliency.attribute(image, target=pred_label_idx)
        _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                              np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
                                              ["original_image", "heat_map"],
                                              ["all", "absolute_value"],
                                              cmap=linear_seg_color_map(name, color, N),
                                              show_colorbar=True)


class LuekemiaNet(pl.LightningModule):
    def __init__(self, lr=0.01, weight_decay=5e-4):
        super(LuekemiaNet, self).__init__()

        self.save_hyperparameters()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(p=0.25),
            nn.AvgPool2d(2))

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(p=0.25),
            nn.AvgPool2d(2))

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Dropout(p=0.25),
            nn.AvgPool2d(2))

        # With 100x100 inputs the three conv/pool stages leave a 128 x 10 x 10 feature map,
        # hence the 128 * 10 * 10 input features below.
        self.fc = nn.Sequential(
            nn.Linear(128 * 10 * 10, 200),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(200, 2))

    def forward(self, x):
        """Forward pass."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        # print(out.shape)  # uncomment for shape debugging

        out = out.view(x.shape[0], -1)
        out = self.fc(out)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')


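# Quick shape sanity check for LuekemiaNet (illustrative only, not part of the original script):
#   LuekemiaNet()(torch.randn(1, 3, 100, 100)).shape  ->  torch.Size([1, 2])
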
image_path = '/home/allen/Drive C/Peter Moss AML Leukemia Research/Dataset/all_test/Im041_0.jpg'
label_idx = '/home/allen/Drive C/Peter Moss AML Leukemia Research/ALL-PyTorch-2020/Classifier/Model/class_idx.json'


seed_everything(SEED)
# train data directory
train_dir = '/home/allen/Drive C/Peter Moss AML Leukemia Research/Dataset/all_train/'
# train label directory
train_csv = '/home/allen/Drive C/Peter Moss AML Leukemia Research/Dataset/train.csv'
# labels
class_name = ["zero", "one"]

# number of epochs
epochs = 20
# learning rate
learn_rate = 0.001

# read the train CSV file
labels = pd.read_csv(train_csv)
# print label counts
labels_count = labels.label.value_counts()
print(labels_count)
# show the first five rows
print(labels.head())
# split the data into training and validation sets
train, valid = train_test_split(labels, stratify=labels.label, test_size=0.1, shuffle=True)
print(len(train), len(valid))
# data augmentation
training_transforms, validation_transforms = augmentation()

train_dataset = LeukemiaDataset(df_data=train, data_dir=train_dir, transform=training_transforms)
valid_dataset = LeukemiaDataset(df_data=valid, data_dir=train_dir, transform=validation_transforms)
train_sampler = SubsetRandomSampler(list(train.index))  # unused: DataLoader shuffling is used instead
valid_sampler = SubsetRandomSampler(list(valid.index))  # unused
# Prepare the data loaders for the network
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

model_path = "weight2.pth"
# use the configured learning rate and weight decay
model = LuekemiaNet(lr=learn_rate, weight_decay=decay)
early_stop_callback = EarlyStopping(
    monitor='val_acc',
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode='max'
)
trainer = pl.Trainer(max_epochs=epochs, log_every_n_steps=2, callbacks=[early_stop_callback])
trainer.fit(model, train_data_loader, valid_data_loader)
trainer.save_checkpoint(model_path)

real_model = LuekemiaNet.load_from_checkpoint(model_path)

y_pred, y_test = get_predictions(real_model, valid_data_loader)
# model precision, recall and f1-score
print(classification_report(y_test, y_pred, target_names=class_name))
# model confusion matrix
cm = confusion_matrix(y_test, y_pred)
confusion_matrix2(cm, class_name, save_path='confusion_matrix.png')

# prediction = predict_probability(real_model, validation_transforms, image_path)
interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True, interpret_type="integrated gradients")
interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True, interpret_type="gradient shap")
interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True, interpret_type="saliency")
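# The noise-tunnel mode is implemented above but not exercised here; it takes the same
# arguments as the calls above (illustrative only):
#   interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True,
#                   interpret_type="integrated gradient noise")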