--- a +++ b/knee_dataset_patches.py @@ -0,0 +1,39 @@ +import os +import os.path +import numpy as np +import random +import h5py +import torch +import glob +import torch.utils.data as udata + + +class MRI_Dataset(udata.Dataset): + def __init__(self, x_path, y_path): + super(MRI_Dataset, self).__init__() + + self.x_path = x_path; + self.y_path = y_path; + + h5f_x = h5py.File(self.x_path, 'r') + h5f_y = h5py.File(self.y_path, 'r') + + self.keys = list(h5f_x.keys()) + random.shuffle(self.keys) + h5f_x.close() + h5f_y.close() + + def __len__(self): + return len(self.keys) + def __getitem__(self, index): + + h5f_x = h5py.File(self.x_path, 'r') + h5f_y = h5py.File(self.y_path, 'r') + + key = self.keys[index] + data_x = np.array(h5f_x[key]) + data_y = np.array(h5f_y[key]) + + h5f_x.close(); + h5f_y.close(); + return (torch.Tensor(data_x), torch.Tensor(data_y),key)