|
a |
|
b/knee_dataset_patches.py |
|
|
1 |
import os |
|
|
2 |
import os.path |
|
|
3 |
import numpy as np |
|
|
4 |
import random |
|
|
5 |
import h5py |
|
|
6 |
import torch |
|
|
7 |
import glob |
|
|
8 |
import torch.utils.data as udata |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class MRI_Dataset(udata.Dataset): |
|
|
12 |
def __init__(self, x_path, y_path): |
|
|
13 |
super(MRI_Dataset, self).__init__() |
|
|
14 |
|
|
|
15 |
self.x_path = x_path; |
|
|
16 |
self.y_path = y_path; |
|
|
17 |
|
|
|
18 |
h5f_x = h5py.File(self.x_path, 'r') |
|
|
19 |
h5f_y = h5py.File(self.y_path, 'r') |
|
|
20 |
|
|
|
21 |
self.keys = list(h5f_x.keys()) |
|
|
22 |
random.shuffle(self.keys) |
|
|
23 |
h5f_x.close() |
|
|
24 |
h5f_y.close() |
|
|
25 |
|
|
|
26 |
def __len__(self): |
|
|
27 |
return len(self.keys) |
|
|
28 |
def __getitem__(self, index): |
|
|
29 |
|
|
|
30 |
h5f_x = h5py.File(self.x_path, 'r') |
|
|
31 |
h5f_y = h5py.File(self.y_path, 'r') |
|
|
32 |
|
|
|
33 |
key = self.keys[index] |
|
|
34 |
data_x = np.array(h5f_x[key]) |
|
|
35 |
data_y = np.array(h5f_y[key]) |
|
|
36 |
|
|
|
37 |
h5f_x.close(); |
|
|
38 |
h5f_y.close(); |
|
|
39 |
return (torch.Tensor(data_x), torch.Tensor(data_y),key) |