[94d9b6]: / knee_dataset_patches.py

Download this file

40 lines (30 with data), 914 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import os.path
import numpy as np
import random
import h5py
import torch
import glob
import torch.utils.data as udata
class MRI_Dataset(udata.Dataset):
def __init__(self, x_path, y_path):
super(MRI_Dataset, self).__init__()
self.x_path = x_path;
self.y_path = y_path;
h5f_x = h5py.File(self.x_path, 'r')
h5f_y = h5py.File(self.y_path, 'r')
self.keys = list(h5f_x.keys())
random.shuffle(self.keys)
h5f_x.close()
h5f_y.close()
def __len__(self):
return len(self.keys)
def __getitem__(self, index):
h5f_x = h5py.File(self.x_path, 'r')
h5f_y = h5py.File(self.y_path, 'r')
key = self.keys[index]
data_x = np.array(h5f_x[key])
data_y = np.array(h5f_y[key])
h5f_x.close();
h5f_y.close();
return (torch.Tensor(data_x), torch.Tensor(data_y),key)