Diff of /evaluate_ensemble.py [000000] .. [94d9b6]

Switch to unified view

a b/evaluate_ensemble.py
1
import time
2
import torch
3
import torch.nn as nn
4
from torch.autograd import Variable
5
import torch.optim as optim
6
import torchvision
7
from torchvision import datasets, models
8
from torchvision import transforms as T
9
from torch.utils.data import DataLoader, Dataset
10
import numpy as np
11
import matplotlib.pyplot as plt
12
import os
13
import time
14
import pandas as pd
15
from skimage import io, transform
16
import matplotlib.image as mpimg
17
from PIL import Image
18
from sklearn.metrics import roc_auc_score
19
import torch.nn.functional as F
20
import scipy
21
import random
22
import pickle
23
import scipy.io as sio
24
import itertools
25
from scipy.ndimage.interpolation import shift
26
import copy
27
import warnings
28
warnings.filterwarnings("ignore")
29
plt.ion()
30
31
from utils import *
32
33
34
def evaluate_pp(model,prediction_models, dataloader, data_size, batch_size, phase, dice_loss = dice_loss,\
35
                smooth = False, filter_size = 3, print_all = False, certainity_map = False):
36
    model.eval()
37
    running_loss = 0
38
    running_dice_score_class_0 = 0
39
    running_dice_score_class_1 = 0
40
    running_dice_score_class_2 = 0
41
    dc_sr = {0:[],1:[],2:[]}
42
    acc_sr = {0:[],1:[],2:[]}
43
    phase = phase
44
    for i in prediction_models:
45
        for param in i.parameters():
46
            param.requires_grad = False
47
    
48
    for i,data in enumerate(dataloader[phase]):
49
        input, segF,segP, segT,_ = data
50
        input = Variable(input).cuda()
51
        input_pp = []
52
        for j in prediction_models:
53
            output = j(input)
54
            preds_m = predict_pp(output)
55
            input_pp.append(preds_m)
56
        input_pp = torch.cat(input_pp,dim = 1)
57
        output_pp = model(input_pp)
58
        true = Variable(segments(segF, segP, segT)).cuda()
59
        loss = dice_loss(true,output_pp)
60
        running_loss += loss.data[0] * batch_size
61
        dice_score_batch = dice_score(true,output_pp, smooth= smooth, filter_size=filter_size)
62
        running_dice_score_class_0 += dice_score_batch[0] * batch_size
63
        running_dice_score_class_1 += dice_score_batch[1] * batch_size
64
        running_dice_score_class_2 += dice_score_batch[2] * batch_size
65
        dc_dict, acc_dict = dice_score_list(true,output_pp)
66
        if certainity_map:
67
            cm = make_certainity_maps(output_pp)
68
        for k in range(3):
69
            dc_sr[k].append(dc_dict[k])
70
            acc_sr[k].append(acc_dict[k])
71
        preds = predict(output_pp,smooth = smooth, filter_size=filter_size)
72
        if i == 11 or i == 4 or print_all:
73
            for k in range(batch_size):
74
                if certainity_map:
75
                    image_to_mask(input[k,1,:,:].data.cpu().numpy(),\
76
                              true[k,0,:,:].data.cpu().numpy(),\
77
                              true[k,1,:,:].data.cpu().numpy(),\
78
                              true[k,2,:,:].data.cpu().numpy(),\
79
                             preds[0][k,:,:].cpu().numpy(),\
80
                             preds[1][k,:,:].cpu().numpy(),\
81
                             preds[2][k,:,:].cpu().numpy(),\
82
                            cm[k].data.cpu().numpy())
83
                else:
84
                    image_to_mask(input[k,1,:,:].data.cpu().numpy(),\
85
                              true[k,0,:,:].data.cpu().numpy(),\
86
                              true[k,1,:,:].data.cpu().numpy(),\
87
                              true[k,2,:,:].data.cpu().numpy(),\
88
                             preds[0][k,:,:].cpu().numpy(),\
89
                             preds[1][k,:,:].cpu().numpy(),\
90
                             preds[2][k,:,:].cpu().numpy())
91
92
    loss = running_loss/data_sizes[phase] 
93
    dice_score_0 = running_dice_score_class_0/data_sizes[phase]
94
    dice_score_1 = running_dice_score_class_1/data_sizes[phase]
95
    dice_score_2 = running_dice_score_class_2/data_sizes[phase]
96
    for i in range(3):
97
        dc_sr[i] = list(itertools.chain(*dc_sr[i]))
98
        acc_sr[i] = list(itertools.chain(*acc_sr[i]))
99
    print('{} loss: {:.4f}, Dice Score (class 0): {:.4f}, Dice Score (class 1): {:.4f},Dice Score (class 2): {:.4f}'.format(phase,loss, dice_score_0, dice_score_1, dice_score_2))
100
    return loss, dice_score_0, dice_score_1, dice_score_2, dc_sr, acc_sr