Diff of /dataloader_2d.py [000000] .. [94d9b6]

Switch to unified view

a b/dataloader_2d.py
1
import torch
2
import torch.nn as nn
3
from torch.autograd import Variable
4
import torch.optim as optim
5
import torchvision
6
from torchvision import datasets, models
7
from torchvision import transforms as T
8
from torch.utils.data import DataLoader, Dataset
9
import numpy as np
10
import matplotlib.pyplot as plt
11
import os
12
import time
13
import pandas as pd
14
from skimage import io, transform
15
import matplotlib.image as mpimg
16
from PIL import Image
17
from sklearn.metrics import roc_auc_score
18
import torch.nn.functional as F
19
import scipy
20
import random
21
import pickle
22
import scipy.io as sio
23
import itertools
24
from scipy.ndimage.interpolation import shift
25
import copy
26
import warnings
27
warnings.filterwarnings("ignore")
28
plt.ion()
29
30
def count_parameters(model):
31
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
32
33
class KneeMRIDataset(Dataset):
34
    '''Knee MRI Dataset'''
35
    def __init__(self, root_dir, label, train_data = False, flipping = True, rotation = True, translation = True,
36
                normalize = False):
37
        self.root_dir = root_dir
38
        self.label = label
39
        self.flipping = flipping
40
        self.rotation = rotation
41
        self.translation = translation
42
        self.train_data = train_data
43
        self.normalize = normalize
44
    
45
    def __len__(self):
46
        return len(self.label)
47
    
48
    def __getitem__(self,idx):
49
        variable_path_name = os.path.join(self.root_dir, self.label[idx])
50
        variables = pickle.load(open(variable_path_name,'rb'))
51
        segment_T = variables['SegmentationT'].astype(float)
52
        segment_F = variables['SegmentationF'].astype(float)
53
        segment_P = variables['SegmentationP'].astype(float)
54
        md = variables['MDnr']
55
        image = variables['NUFnr']
56
        fa = variables['FAnr']
57
        images = []
58
        flip = random.random() > 0.5
59
        angle = random.uniform(-4,4)
60
        dx = np.round(random.uniform(-7,7))
61
        dy = np.round(random.uniform(-7,7))
62
        for i in range(7):
63
            im = image[:,:,i]
64
            if self.train_data:
65
                if self.flipping and flip:
66
                    im = np.fliplr(im)
67
                if self.rotation:
68
                    im = transform.rotate(im, angle, order = 0)
69
                if self.translation:
70
                    im = shift(im,(dx,dy), order = 0)
71
            im = torch.from_numpy(im).type(torch.DoubleTensor)
72
            images.append(im.unsqueeze(0))
73
        if self.train_data:
74
            if self.flipping and flip:
75
                segment_T = np.fliplr(segment_T)
76
                segment_F = np.fliplr(segment_F)
77
                segment_P = np.fliplr(segment_P)
78
                md = np.fliplr(md)
79
                fa = np.fliplr(fa)
80
            if self.rotation:
81
                segment_T = transform.rotate(segment_T,angle, order = 0)
82
                segment_F = transform.rotate(segment_F,angle, order = 0)
83
                segment_P = transform.rotate(segment_P,angle, order = 0)
84
                md = transform.rotate(md,angle, order = 0)
85
                fa = transform.rotate(fa,angle, order = 0)
86
            if self.translation:
87
                segment_T = shift(segment_T,(dx,dy), order = 0)
88
                segment_F = shift(segment_F,(dx,dy), order = 0)
89
                segment_P = shift(segment_P,(dx,dy), order = 0)
90
                md = shift(md,(dx,dy), order = 0)
91
                fa = shift(fa,(dx,dy), order = 0)
92
        fa = torch.from_numpy(fa).type(torch.DoubleTensor)
93
        md = torch.from_numpy(md).type(torch.DoubleTensor)
94
        images.append(fa.unsqueeze(0))
95
        images.append(md.unsqueeze(0))
96
        image_all = torch.cat(images, dim = 0)
97
        segment_F = torch.from_numpy(segment_F)
98
        segment_T = torch.from_numpy(segment_T)
99
        segment_P = torch.from_numpy(segment_P)
100
        if self.normalize:
101
#             max_image, min_image, max_fa, min_fa, max_md, min_md = pickle.load(open('normalizing_values','rb'))
102
            max_image, min_image = torch.max(image_all[:7,:,:]), torch.min(image_all[:7,:,:])
103
            max_fa, min_fa = torch.max(image_all[7,:,:]), torch.min(image_all[7,:,:])
104
            max_md, min_md = torch.max(image_all[8,:,:]), torch.min(image_all[8,:,:])
105
            image_all[:7,:,:] = (image_all[:7,:,:] - min_image)/(max_image - min_image)
106
            image_all[7,:,:] = (image_all[7,:,:] - min_fa)/(max_fa - min_fa)
107
            image_all[-1,:,:] = (image_all[-1,:,:] - min_md)/(max_md - min_md)
108
        
109
        return (image_all.type(torch.FloatTensor), segment_F.type(torch.FloatTensor), 
110
                segment_P.type(torch.FloatTensor), segment_T.type(torch.FloatTensor),self.label[idx])
111