Diff of /knee_dataset_patches.py [000000] .. [94d9b6]

Switch to unified view

a b/knee_dataset_patches.py
1
import os
2
import os.path
3
import numpy as np
4
import random
5
import h5py
6
import torch
7
import glob
8
import torch.utils.data as udata
9
10
11
class MRI_Dataset(udata.Dataset):
12
    def __init__(self, x_path, y_path):
13
        super(MRI_Dataset, self).__init__()
14
15
        self.x_path = x_path;
16
        self.y_path = y_path;
17
  
18
        h5f_x = h5py.File(self.x_path, 'r')
19
        h5f_y = h5py.File(self.y_path, 'r')
20
21
        self.keys = list(h5f_x.keys())
22
        random.shuffle(self.keys)
23
        h5f_x.close()
24
        h5f_y.close()
25
26
    def __len__(self):
27
        return len(self.keys)
28
    def __getitem__(self, index):
29
30
        h5f_x = h5py.File(self.x_path, 'r')
31
        h5f_y = h5py.File(self.y_path, 'r')
32
33
        key = self.keys[index]
34
        data_x = np.array(h5f_x[key])
35
        data_y = np.array(h5f_y[key])
36
37
        h5f_x.close();
38
        h5f_y.close();
39
        return (torch.Tensor(data_x), torch.Tensor(data_y),key)