--- a +++ b/Learning/load_data.py @@ -0,0 +1,40 @@ +import os +import pandas as pd +import h5py + +import torch +from torch.utils.data import Dataset + +def load_feature(path): + with h5py.File(path, 'r') as f: + X = f['X'][()] + X = torch.FloatTensor(X) + return X + +class VoxelDataset(Dataset): + + def __init__(self, label_file, root_dir): + """ + Args: + label_file (string): Path to the csv file with labels. + root_dir (string): Directory with all the voxel data. + transform (callable, optional): Optional transform to be applied + on a sample. + """ + self.labels = pd.read_csv(label_file) + self.root_dir = root_dir + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + voxel_name = os.path.join(self.root_dir, + self.labels.iloc[idx, 0]) + '.h5' + voxel = load_feature(voxel_name) + label = self.labels.iloc[idx, 1] + sample = {'voxel': voxel, 'label': label} + + return sample