Diff of /Learning/load_data.py [000000] .. [b623ff]

Switch to unified view

a b/Learning/load_data.py
1
import os
2
import pandas as pd
3
import h5py
4
5
import torch
6
from torch.utils.data import Dataset
7
8
def load_feature(path):
9
    with h5py.File(path, 'r') as f:
10
        X = f['X'][()]
11
    X = torch.FloatTensor(X)
12
    return X
13
14
class VoxelDataset(Dataset):
15
16
    def __init__(self, label_file, root_dir):
17
        """
18
        Args:
19
            label_file (string): Path to the csv file with labels.
20
            root_dir (string): Directory with all the voxel data.
21
            transform (callable, optional): Optional transform to be applied
22
                on a sample.
23
        """
24
        self.labels = pd.read_csv(label_file)
25
        self.root_dir = root_dir
26
27
    def __len__(self):
28
        return len(self.labels)
29
30
    def __getitem__(self, idx):
31
        if torch.is_tensor(idx):
32
            idx = idx.tolist()
33
34
        voxel_name = os.path.join(self.root_dir,
35
                                self.labels.iloc[idx, 0]) + '.h5'
36
        voxel = load_feature(voxel_name)
37
        label = self.labels.iloc[idx, 1]
38
        sample = {'voxel': voxel, 'label': label}
39
40
        return sample