|
a |
|
b/Classifier/AllPytorch.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import pytorch_lightning as pl |
|
|
3 |
from pytorch_lightning.callbacks import EarlyStopping |
|
|
4 |
from pytorch_lightning.metrics.functional import accuracy |
|
|
5 |
from pytorch_lightning.callbacks import LearningRateMonitor |
|
|
6 |
from torch.optim.lr_scheduler import OneCycleLR |
|
|
7 |
import torch |
|
|
8 |
import torch.nn as nn |
|
|
9 |
import os |
|
|
10 |
import numpy as np |
|
|
11 |
import pandas as pd |
|
|
12 |
from torch.utils.data import DataLoader |
|
|
13 |
from torch.utils.data.sampler import SubsetRandomSampler |
|
|
14 |
import os |
|
|
15 |
import torch.nn.functional as F |
|
|
16 |
from torch.utils.data import Dataset |
|
|
17 |
from torchvision import transforms, datasets |
|
|
18 |
from torch.optim import Adam |
|
|
19 |
import random |
|
|
20 |
from sklearn.model_selection import train_test_split |
|
|
21 |
from sklearn.metrics import confusion_matrix, classification_report |
|
|
22 |
import seaborn as sns |
|
|
23 |
import PIL.Image as Image |
|
|
24 |
|
|
|
25 |
import json |
|
|
26 |
import numpy as np |
|
|
27 |
from matplotlib.colors import LinearSegmentedColormap |
|
|
28 |
|
|
|
29 |
from captum.attr import IntegratedGradients |
|
|
30 |
from captum.attr import GradientShap |
|
|
31 |
from captum.attr import NoiseTunnel |
|
|
32 |
from captum.attr import Saliency |
|
|
33 |
from captum.attr import visualization as viz |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
SEED = 323 |
|
|
37 |
def seed_everything(seed=SEED): |
|
|
38 |
random.seed(seed) |
|
|
39 |
os.environ['PYHTONHASHSEED'] = str(seed) |
|
|
40 |
np.random.seed(seed) |
|
|
41 |
torch.manual_seed(seed) |
|
|
42 |
torch.cuda.manual_seed(seed) |
|
|
43 |
torch.backends.cudnn.deterministic = True |
|
|
44 |
|
|
|
45 |
#decay |
|
|
46 |
decay = 5e-4 |
|
|
47 |
# training batch size |
|
|
48 |
batch_size = 10 |
|
|
49 |
# check if cuda is available: if not available, then use cpu |
|
|
50 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
51 |
|
|
|
52 |
|
|
|
53 |
class LeukemiaDataset(Dataset): |
|
|
54 |
|
|
|
55 |
""" |
|
|
56 |
Acute Lymphoblastic Leukemia Dataset Reader. |
|
|
57 |
|
|
|
58 |
Args: |
|
|
59 |
df_data: Dataframe for CSV file |
|
|
60 |
data_dir: path to Lymphoblastic Leukemia Data |
|
|
61 |
transform: transforms for performing data augmentation |
|
|
62 |
""" |
|
|
63 |
|
|
|
64 |
def __init__(self, df_data, data_dir='./', transform=None): |
|
|
65 |
super().__init__() |
|
|
66 |
self.df = df_data.values |
|
|
67 |
self.data_dir = data_dir |
|
|
68 |
self.transform = transform |
|
|
69 |
|
|
|
70 |
def __len__(self): |
|
|
71 |
return len(self.df) |
|
|
72 |
|
|
|
73 |
def __getitem__(self, index): |
|
|
74 |
img_name, label = self.df[index] |
|
|
75 |
img_path = os.path.join(self.data_dir, img_name + '.jpg') |
|
|
76 |
image = Image.open(img_path) |
|
|
77 |
if self.transform is not None: |
|
|
78 |
image = self.transform(image) |
|
|
79 |
return image, label |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
def augmentation(): |
|
|
83 |
"""Acute Lymphoblastic Leukemia data augmentation""" |
|
|
84 |
mean, std_dev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] |
|
|
85 |
training_transforms = transforms.Compose([transforms.Resize((100, 100)), |
|
|
86 |
transforms.RandomRotation(30), |
|
|
87 |
transforms.RandomResizedCrop(100), |
|
|
88 |
transforms.RandomHorizontalFlip(), |
|
|
89 |
transforms.RandomGrayscale(p=0.1), |
|
|
90 |
transforms.ToTensor(), |
|
|
91 |
transforms.Normalize(mean=mean, std=std_dev)]) |
|
|
92 |
|
|
|
93 |
validation_transforms = transforms.Compose([ |
|
|
94 |
transforms.Resize((100, 100)), |
|
|
95 |
transforms.ToTensor(), |
|
|
96 |
transforms.Normalize(mean=mean, std=std_dev)]) |
|
|
97 |
|
|
|
98 |
return training_transforms, validation_transforms |
|
|
99 |
|
|
|
100 |
|
|
|
101 |
|
|
|
102 |
|
|
|
103 |
def predict_probability(model, transforms, image_path, use_gpu=True): |
|
|
104 |
img = Image.open(image_path) |
|
|
105 |
img = img.convert('RGB') |
|
|
106 |
img = transforms(img).unsqueeze(0) |
|
|
107 |
if use_gpu: |
|
|
108 |
image = img.cuda() |
|
|
109 |
|
|
|
110 |
predictions = model(image) |
|
|
111 |
predictions = torch.sigmoid(predictions) |
|
|
112 |
predictions = predictions.detach().cpu().numpy().flatten() |
|
|
113 |
return predictions |
|
|
114 |
|
|
|
115 |
|
|
|
116 |
def show_prediction_confidence(prediction, class_names): |
|
|
117 |
pred_df = pd.DataFrame({ |
|
|
118 |
'class_names': class_names, |
|
|
119 |
'values': prediction |
|
|
120 |
}) |
|
|
121 |
sns.barplot(x='values', y='class_names', data=pred_df, orient='h') |
|
|
122 |
sns |
|
|
123 |
plt.xlim([0, 1]) |
|
|
124 |
|
|
|
125 |
def get_predictions(model, data_loader, use_gpu=True): |
|
|
126 |
model = model.eval() |
|
|
127 |
y_predictions = list() |
|
|
128 |
y_true = list() |
|
|
129 |
with torch.no_grad(): |
|
|
130 |
for i, dataset in enumerate(data_loader): |
|
|
131 |
inputs, labels = dataset |
|
|
132 |
inputs, labels = inputs.to(device), labels.to(device) |
|
|
133 |
|
|
|
134 |
outputs = model(inputs) |
|
|
135 |
_, preds = torch.max(outputs, 1) |
|
|
136 |
y_predictions.extend(preds) |
|
|
137 |
y_true.extend(labels) |
|
|
138 |
|
|
|
139 |
predictions = torch.as_tensor(y_predictions).cpu() |
|
|
140 |
y_true = torch.as_tensor(y_true).cpu() |
|
|
141 |
return predictions, y_true |
|
|
142 |
|
|
|
143 |
|
|
|
144 |
def confusion_matrix2(confusion_matrix, class_names, save_path): |
|
|
145 |
cm = confusion_matrix.copy() |
|
|
146 |
|
|
|
147 |
cell_counts = cm.flatten() |
|
|
148 |
|
|
|
149 |
cm_row_norm = cm / cm.sum(axis=1)[:, np.newaxis] |
|
|
150 |
|
|
|
151 |
row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()] |
|
|
152 |
|
|
|
153 |
cell_labels = [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)] |
|
|
154 |
cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1]) |
|
|
155 |
|
|
|
156 |
df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names) |
|
|
157 |
|
|
|
158 |
hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues") |
|
|
159 |
hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right') |
|
|
160 |
hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right') |
|
|
161 |
plt.ylabel('True diagnostic') |
|
|
162 |
plt.xlabel('Predicted diagnostic') |
|
|
163 |
plt.savefig(save_path) |
|
|
164 |
plt.show() |
|
|
165 |
|
|
|
166 |
|
|
|
167 |
color = [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')] |
|
|
168 |
name = 'custom blue' |
|
|
169 |
N = 256 |
|
|
170 |
|
|
|
171 |
def linear_seg_color_map(name, color, N, gamma=1.0): |
|
|
172 |
""" |
|
|
173 |
Render color map based on lookup tables |
|
|
174 |
:param name: name the color |
|
|
175 |
:param color: color code |
|
|
176 |
:param N: number of RGB quantization |
|
|
177 |
:param gamma: default is 1.0 |
|
|
178 |
:return: |
|
|
179 |
""" |
|
|
180 |
default_cmap = LinearSegmentedColormap.from_list(name, color, N, gamma) |
|
|
181 |
return default_cmap |
|
|
182 |
|
|
|
183 |
def predict(model, transforms, image_path, use_cpu): |
|
|
184 |
""" |
|
|
185 |
|
|
|
186 |
:param model: |
|
|
187 |
:param transforms: data transform |
|
|
188 |
:param image_path: inference input image path |
|
|
189 |
:param use_cpu: |
|
|
190 |
:return: |
|
|
191 |
""" |
|
|
192 |
model.cpu() |
|
|
193 |
model = model.eval() |
|
|
194 |
img = Image.open(image_path) |
|
|
195 |
img = img.convert('RGB') |
|
|
196 |
transformed_img = transforms(img) |
|
|
197 |
image = transformed_img |
|
|
198 |
image = image.unsqueeze(0) |
|
|
199 |
if use_cpu: |
|
|
200 |
image = image.cpu() |
|
|
201 |
elif use_cpu == False: |
|
|
202 |
model.cuda() |
|
|
203 |
image = image.cuda |
|
|
204 |
output = model(image) |
|
|
205 |
|
|
|
206 |
output = torch.softmax(output, dim=1) |
|
|
207 |
prediction_score, pred_label_idx = torch.topk(output, 1) |
|
|
208 |
return image, transformed_img, prediction_score, pred_label_idx |
|
|
209 |
|
|
|
210 |
def interpret_model(model, transforms, image_path, label_path, use_cpu=True, interpret_type=""): |
|
|
211 |
|
|
|
212 |
""" |
|
|
213 |
:param model: our model |
|
|
214 |
:param transforms: Data transformation |
|
|
215 |
:param image_path: Image directory |
|
|
216 |
:param label_path: Json label directory |
|
|
217 |
:param use_gpu: set gpu to True |
|
|
218 |
:param interpret_type: mode for model interpretability: "integrated gradients" |
|
|
219 |
for Integrated Gradients, "gradient shap" for Gradient Shap and "occlusion" for Occlusion |
|
|
220 |
:return: |
|
|
221 |
""" |
|
|
222 |
with open(label_path) as json_data: |
|
|
223 |
idx_to_labels = json.load(json_data) |
|
|
224 |
|
|
|
225 |
# Check if mode is Integrated Gradients |
|
|
226 |
if interpret_type == "integrated gradients": |
|
|
227 |
|
|
|
228 |
print('Performing Integrated Gradients Model Interpretation', interpret_type) |
|
|
229 |
image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu) |
|
|
230 |
pred_label_idx.squeeze_() |
|
|
231 |
predicted_label = idx_to_labels[str(pred_label_idx.item())][1] |
|
|
232 |
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')') |
|
|
233 |
|
|
|
234 |
integrated_gradients = IntegratedGradients(model) |
|
|
235 |
attributions_ig = integrated_gradients.attribute(image, target=pred_label_idx, n_steps=20) |
|
|
236 |
|
|
|
237 |
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
238 |
np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
239 |
["original_image", "heat_map"], |
|
|
240 |
["all", "absolute_value"], |
|
|
241 |
cmap=linear_seg_color_map(name, color, N), |
|
|
242 |
show_colorbar=True) |
|
|
243 |
|
|
|
244 |
# Check if mode is Integrated Gradients with Noise Tunnel |
|
|
245 |
elif interpret_type == "integrated gradient noise": |
|
|
246 |
print('Performing Integrated Gradients Noise Tunnel Model Interpretation', interpret_type) |
|
|
247 |
image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu) |
|
|
248 |
pred_label_idx.squeeze_() |
|
|
249 |
predicted_label = idx_to_labels[str(pred_label_idx.item())][1] |
|
|
250 |
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')') |
|
|
251 |
|
|
|
252 |
integrated_gradients = IntegratedGradients(model) |
|
|
253 |
noise_tunnel = NoiseTunnel(integrated_gradients) |
|
|
254 |
|
|
|
255 |
attributions_ig_nt = noise_tunnel.attribute(image, n_samples=10, nt_type='smoothgrad_sq', target=pred_label_idx) |
|
|
256 |
_ = viz.visualize_image_attr_multiple( |
|
|
257 |
np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
258 |
np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
259 |
["original_image", "heat_map"], |
|
|
260 |
["all", "positive"], |
|
|
261 |
cmap=linear_seg_color_map(name, color, N), |
|
|
262 |
show_colorbar=True) |
|
|
263 |
|
|
|
264 |
|
|
|
265 |
|
|
|
266 |
# Check if mode is Gradient Shap |
|
|
267 |
elif interpret_type == "gradient shap": |
|
|
268 |
|
|
|
269 |
print('Performing Gradient Shap Model Interpretation', interpret_type) |
|
|
270 |
image, transformed_img, prediction_score, pred_label_idx = predict(model,transforms, image_path, use_cpu) |
|
|
271 |
pred_label_idx.squeeze_() |
|
|
272 |
predicted_label = idx_to_labels[str(pred_label_idx.item())][1] |
|
|
273 |
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')') |
|
|
274 |
gradient_shap = GradientShap(model) |
|
|
275 |
|
|
|
276 |
# Defining baseline distribution of images |
|
|
277 |
rand_img_dist = torch.cat([image * 0, image * 1]) |
|
|
278 |
|
|
|
279 |
attributions_gs = gradient_shap.attribute(image, |
|
|
280 |
n_samples=50, |
|
|
281 |
stdevs=0.0001, |
|
|
282 |
baselines=rand_img_dist, |
|
|
283 |
target=pred_label_idx) |
|
|
284 |
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
285 |
np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
286 |
["original_image", "heat_map"], |
|
|
287 |
["all", "absolute_value"], |
|
|
288 |
cmap=linear_seg_color_map(name, color, N), |
|
|
289 |
show_colorbar=True) |
|
|
290 |
|
|
|
291 |
# Check if mode is Saliency |
|
|
292 |
elif interpret_type == "saliency": |
|
|
293 |
|
|
|
294 |
print('Performing Saliency Model Interpretation', interpret_type) |
|
|
295 |
image, transformed_img, prediction_score, pred_label_idx = predict(model, transforms, image_path, use_cpu) |
|
|
296 |
pred_label_idx.squeeze_() |
|
|
297 |
predicted_label = idx_to_labels[str(pred_label_idx.item())][1] |
|
|
298 |
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')') |
|
|
299 |
saliency = Saliency(model) |
|
|
300 |
attributions_gs = saliency.attribute(image,target=pred_label_idx) |
|
|
301 |
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
302 |
np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)), |
|
|
303 |
["original_image", "heat_map"], |
|
|
304 |
["all", "absolute_value"], |
|
|
305 |
cmap=linear_seg_color_map(name, color, N), |
|
|
306 |
show_colorbar=True) |
|
|
307 |
|
|
|
308 |
|
|
|
309 |
|
|
|
310 |
class LuekemiaNet(pl.LightningModule): |
|
|
311 |
def __init__(self, lr=0.01, weight_decay=5e-4): |
|
|
312 |
super(LuekemiaNet, self).__init__() |
|
|
313 |
|
|
|
314 |
self.save_hyperparameters() |
|
|
315 |
self.conv1 = nn.Sequential( |
|
|
316 |
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=0), |
|
|
317 |
nn.ReLU(), |
|
|
318 |
nn.BatchNorm2d(32), |
|
|
319 |
nn.Dropout(p=0.25), |
|
|
320 |
nn.AvgPool2d(2)) |
|
|
321 |
|
|
|
322 |
self.conv2 = nn.Sequential( |
|
|
323 |
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=1), |
|
|
324 |
nn.ReLU(), |
|
|
325 |
nn.BatchNorm2d(64), |
|
|
326 |
nn.Dropout(p=0.25), |
|
|
327 |
nn.AvgPool2d(2)) |
|
|
328 |
|
|
|
329 |
self.conv3 = nn.Sequential( |
|
|
330 |
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1), |
|
|
331 |
nn.ReLU(), |
|
|
332 |
nn.BatchNorm2d(128), |
|
|
333 |
nn.Dropout(p=0.25), |
|
|
334 |
nn.AvgPool2d(2)) |
|
|
335 |
|
|
|
336 |
|
|
|
337 |
self.fc = nn.Sequential( |
|
|
338 |
nn.Linear(128 * 10 * 10, 200), |
|
|
339 |
nn.ReLU(), |
|
|
340 |
nn.Dropout(), |
|
|
341 |
nn.Linear(200, 2)) |
|
|
342 |
|
|
|
343 |
def forward(self, x): |
|
|
344 |
"""Method for Forward Prop""" |
|
|
345 |
out = self.conv1(x) |
|
|
346 |
out = self.conv2(out) |
|
|
347 |
out = self.conv3(out) |
|
|
348 |
###################### |
|
|
349 |
# For model debugging # |
|
|
350 |
###################### |
|
|
351 |
#print(out.shape) |
|
|
352 |
|
|
|
353 |
out = out.view(x.shape[0], -1) |
|
|
354 |
out = self.fc(out) |
|
|
355 |
return out |
|
|
356 |
|
|
|
357 |
def configure_optimizers(self): |
|
|
358 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) |
|
|
359 |
return optimizer |
|
|
360 |
|
|
|
361 |
def training_step(self, batch, batch_idx): |
|
|
362 |
x, y = batch |
|
|
363 |
logits = self(x) |
|
|
364 |
loss = F.cross_entropy(logits, y) |
|
|
365 |
self.log('train_loss', loss) |
|
|
366 |
return loss |
|
|
367 |
|
|
|
368 |
def evaluate(self, batch, stage=None): |
|
|
369 |
x, y = batch |
|
|
370 |
logits = self(x) |
|
|
371 |
loss = F.cross_entropy(logits, y) |
|
|
372 |
preds = torch.argmax(logits, dim=1) |
|
|
373 |
acc = accuracy(preds, y) |
|
|
374 |
|
|
|
375 |
if stage: |
|
|
376 |
self.log(f'{stage}_loss', loss, prog_bar=True) |
|
|
377 |
self.log(f'{stage}_acc', acc, prog_bar=True) |
|
|
378 |
|
|
|
379 |
def validation_step(self, batch, batch_idx): |
|
|
380 |
self.evaluate(batch, 'val') |
|
|
381 |
|
|
|
382 |
def test_step(self, batch, batch_idx): |
|
|
383 |
self.evaluate(batch, 'test') |
|
|
384 |
|
|
|
385 |
|
|
|
386 |
|
|
|
387 |
|
|
|
388 |
image_path = '/home/allen/Drive C/Peter Moss AML Leukemia Research/Dataset/all_test/Im041_0.jpg' |
|
|
389 |
label_idx = '/home/allen/Drive C/Peter Moss AML Leukemia Research/ALL-PyTorch-2020/Classifier/Model/class_idx.json' |
|
|
390 |
|
|
|
391 |
|
|
|
392 |
seed_everything(SEED) |
|
|
393 |
# train data directory |
|
|
394 |
train_dir = '/home/allen/Drive C/Peter Moss AML Leukemia Research/Dataset/all_train/' |
|
|
395 |
# train label directoy |
|
|
396 |
train_csv = '/home/allen/Drive C/Peter Moss AML Leukemia Research/Dataset/train.csv' |
|
|
397 |
# labels |
|
|
398 |
class_name = ["zero", "one"] |
|
|
399 |
|
|
|
400 |
# number of epoch |
|
|
401 |
epochs = 20 |
|
|
402 |
# learning rate |
|
|
403 |
learn_rate = 0.001 |
|
|
404 |
# read train CSV file |
|
|
405 |
|
|
|
406 |
labels = pd.read_csv(train_csv) |
|
|
407 |
# print label count |
|
|
408 |
labels_count = labels.label.value_counts() |
|
|
409 |
print(labels_count) |
|
|
410 |
# print 5 label header |
|
|
411 |
print(labels.head()) |
|
|
412 |
# splitting data into training and validation set |
|
|
413 |
train, valid = train_test_split(labels, stratify = labels.label, test_size = 0.1, shuffle=True) |
|
|
414 |
print(len(train),len(valid)) |
|
|
415 |
#data augmentation |
|
|
416 |
training_transforms, validation_transforms = augmentation() |
|
|
417 |
|
|
|
418 |
train_dataset = LeukemiaDataset(df_data=train, data_dir=train_dir, transform=training_transforms) |
|
|
419 |
valid_dataset = LeukemiaDataset(df_data=valid, data_dir=train_dir, transform=validation_transforms) |
|
|
420 |
train_sampler = SubsetRandomSampler(list(train.index)) |
|
|
421 |
valid_sampler = SubsetRandomSampler(list(valid.index)) |
|
|
422 |
# Prepare dataset for neural networks |
|
|
423 |
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) |
|
|
424 |
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4) |
|
|
425 |
|
|
|
426 |
model_path = "weight2.pth" |
|
|
427 |
model = LuekemiaNet() |
|
|
428 |
early_stop_callback = EarlyStopping( |
|
|
429 |
monitor='val_acc', |
|
|
430 |
min_delta=0.00, |
|
|
431 |
patience=3, |
|
|
432 |
verbose=False, |
|
|
433 |
mode='max' |
|
|
434 |
) |
|
|
435 |
trainer = pl.Trainer(max_epochs=epochs, log_every_n_steps=2, callbacks=[early_stop_callback]) |
|
|
436 |
trainer.fit(model, train_data_loader, valid_data_loader) |
|
|
437 |
trainer.save_checkpoint(model_path) |
|
|
438 |
|
|
|
439 |
real_model = model.load_from_checkpoint(model_path) |
|
|
440 |
|
|
|
441 |
y_pred, y_test = get_predictions(real_model, valid_data_loader) |
|
|
442 |
# Get model precision, recall and f1_score |
|
|
443 |
print(classification_report(y_test, y_pred, target_names=class_name)) |
|
|
444 |
# Get model confusion matrix |
|
|
445 |
cm = confusion_matrix(y_test, y_pred) |
|
|
446 |
confusion_matrix2(cm, class_name,save_path='confusion_matrix.png') |
|
|
447 |
|
|
|
448 |
#prediction = predict_probability(real_model, validation_transforms, image_path) |
|
|
449 |
interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True, interpret_type="integrated gradients") |
|
|
450 |
interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True, interpret_type="gradient shap") |
|
|
451 |
interpret_model(real_model, validation_transforms, image_path, label_idx, use_cpu=True, interpret_type="saliency") |
|
|
452 |
|
|
|
453 |
|