Diff of /dataloader_3d.py [000000] .. [94d9b6]

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision
from torchvision import datasets, models
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import pandas as pd
from skimage import io, transform
import matplotlib.image as mpimg
from PIL import Image
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
import scipy
import random
import pickle
import scipy.io as sio
import itertools
from scipy.ndimage import shift  # scipy.ndimage.interpolation is deprecated in recent SciPy

import warnings
warnings.filterwarnings("ignore")
plt.ion()

class KneeMRI3DDataset(Dataset):
    '''Knee MRI Dataset: loads one .mat file per sample containing the image
    volume (NUFnr), FA and MD maps, and three segmentation masks.'''

    def __init__(self, root_dir, label, train_data=False, flipping=True, rotation=True,
                 translation=True, normalize=False):
        self.root_dir = root_dir
        self.label = label              # list of .mat file names, one per sample
        # Augmentation flags; stored but not applied in __getitem__ in this file
        self.flipping = flipping
        self.rotation = rotation
        self.translation = translation
        self.train_data = train_data
        self.normalize = normalize

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        variable_path_name = os.path.join(self.root_dir, self.label[idx])
        variables = sio.loadmat(variable_path_name)

        # Segmentation masks, reordered to (slice, H, W)
        segment_T = variables['SegmentationT'].transpose(2, 0, 1).astype(float)
        segment_F = variables['SegmentationF'].transpose(2, 0, 1).astype(float)
        segment_P = variables['SegmentationP'].transpose(2, 0, 1).astype(float)

        # Image channels: multi-channel NUFnr volume followed by the FA and MD maps
        images = []
        md = variables['MDnr'].transpose(2, 0, 1)
        image = variables['NUFnr'].transpose(3, 2, 0, 1)
        fa = variables['FAnr'].transpose(2, 0, 1)
        image = torch.from_numpy(image).type(torch.FloatTensor)
        fa = torch.from_numpy(fa).type(torch.FloatTensor)
        md = torch.from_numpy(md).type(torch.FloatTensor)
        images.append(image)
        images.append(fa.unsqueeze(0))
        images.append(md.unsqueeze(0))
        image_all = torch.cat(images, dim=0)

        # Three segmentation masks (F, T, P) plus a background channel
        segment_F = torch.from_numpy(segment_F).type(torch.FloatTensor)
        segment_T = torch.from_numpy(segment_T).type(torch.FloatTensor)
        segment_P = torch.from_numpy(segment_P).type(torch.FloatTensor)
        segments = []
        segments.append(segment_F.unsqueeze(0))
        segments.append(segment_T.unsqueeze(0))
        segments.append(segment_P.unsqueeze(0))
        seg_tot = segment_F + segment_T + segment_P
        seg_none = (seg_tot == 0).type(torch.FloatTensor)   # voxels with no tissue label
        segments.append(seg_none.unsqueeze(0))
        segments_all = torch.cat(segments, dim=0)

        if self.normalize:
            # Min-max normalize each channel group using precomputed dataset-wide statistics
            with open('normalizing_values_new', 'rb') as f:
                max_image, min_image, max_fa, min_fa, max_md, min_md = pickle.load(f)
            image_all[:7, :, :, :] = (image_all[:7, :, :, :] - min_image) / (max_image - min_image)
            image_all[7, :, :, :] = (image_all[7, :, :, :] - min_fa) / (max_fa - min_fa)
            image_all[-1, :, :, :] = (image_all[-1, :, :, :] - min_md) / (max_md - min_md)

        return (image_all, segments_all, self.label[idx])
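
Usage sketch (not part of the diff above): one way to instantiate this dataset and wrap it in a DataLoader. The root directory and file names below are placeholder assumptions; the batch size is arbitrary.

# Hypothetical usage example; './data_3d' and the file names are assumptions.
train_files = ['subject_01.mat', 'subject_02.mat']   # placeholder .mat file names
train_set = KneeMRI3DDataset(root_dir='./data_3d', label=train_files,
                             train_data=True, normalize=False)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=2)

for image_all, segments_all, names in train_loader:
    # image_all: (batch, channels, slices, H, W); segments_all: (batch, 4, slices, H, W)
    print(names, image_all.shape, segments_all.shape)
    break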