Diff of /dataloaders/la_heart.py [000000] .. [903821]

Switch to unified view

a b/dataloaders/la_heart.py
1
import h5py
2
import torch
3
import numpy as np
4
import itertools
5
from torch.utils.data import Dataset
6
from torch.utils.data.sampler import Sampler
7
8
9
class LAHeart(Dataset):
    """LA dataset backed by one HDF5 volume per case.

    Case names are read from ``train.list`` or ``test.list`` located one
    directory above ``base_dir``. Each sample is loaded from
    ``<base_dir>/<case>/mri_norm2.h5``, which must provide the datasets
    ``image`` and ``label``.

    Args:
        base_dir: Directory containing one sub-directory per case.
        split: ``'train'`` or ``'test'``; selects which list file is read.
        num: If not ``None``, keep only the first ``num`` cases.
        transform: Optional callable applied to the
            ``{'image': ..., 'label': ...}`` dict before it is returned.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    # Maps each supported split to the list file that enumerates its cases.
    _SPLIT_FILES = {'train': 'train.list', 'test': 'test.list'}

    def __init__(self, base_dir=None, split='train', num=None, transform=None):
        self._base_dir = base_dir
        self.transform = transform
        self.sample_list = []

        # Fail early with a clear message instead of an AttributeError on the
        # never-assigned ``image_list`` further down.
        if split not in self._SPLIT_FILES:
            raise ValueError(
                "split must be 'train' or 'test', got {!r}".format(split))

        with open(self._base_dir + '/../' + self._SPLIT_FILES[split], 'r') as f:
            # Skip blank lines (e.g. a trailing newline) so they do not turn
            # into empty case names.
            self.image_list = [line.strip() for line in f if line.strip()]

        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        """Return the number of cases in the selected split."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load the case at position ``idx`` and apply the optional transform."""
        image_name = self.image_list[idx]
        # The context manager closes the HDF5 handle once both arrays have
        # been copied into memory by the ``[:]`` reads.
        with h5py.File(self._base_dir + "/" + image_name + "/mri_norm2.h5", 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
40
41
class RandomCrop(object):
    """Cut a randomly positioned sub-volume out of a sample.

    Args:
        output_size: Sequence whose first three entries give the crop extent
            along the three array axes.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Return a dict with identically cropped ``image`` and ``label``."""
        image, label = sample['image'], sample['label']
        target = [self.output_size[axis] for axis in range(3)]

        # Zero-pad both arrays symmetrically whenever some axis of the label
        # does not exceed the requested extent.
        if any(label.shape[axis] <= target[axis] for axis in range(3)):
            margins = [
                max((target[axis] - label.shape[axis]) // 2 + 3, 0)
                for axis in range(3)
            ]
            pad_width = [(margin, margin) for margin in margins]
            image = np.pad(image, pad_width, mode='constant', constant_values=0)
            label = np.pad(label, pad_width, mode='constant', constant_values=0)

        width, height, depth = image.shape
        # One offset per axis, drawn in axis order.
        offsets = [
            np.random.randint(0, extent - size)
            for extent, size in zip((width, height, depth), target)
        ]
        window = tuple(
            slice(start, start + size) for start, size in zip(offsets, target)
        )
        return {'image': image[window], 'label': label[window]}
71
72
class CenterCrop(object):
    """Cut the central sub-volume out of a sample.

    Args:
        output_size: Sequence whose first three entries give the crop extent
            along the three array axes.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Return a dict with identically centre-cropped ``image`` and ``label``."""
        image, label = sample['image'], sample['label']
        target = [self.output_size[axis] for axis in range(3)]

        # Zero-pad both arrays symmetrically whenever some axis of the label
        # does not exceed the requested extent.
        if any(label.shape[axis] <= target[axis] for axis in range(3)):
            margins = [
                max((target[axis] - label.shape[axis]) // 2 + 3, 0)
                for axis in range(3)
            ]
            pad_width = [(margin, margin) for margin in margins]
            image = np.pad(image, pad_width, mode='constant', constant_values=0)
            label = np.pad(label, pad_width, mode='constant', constant_values=0)

        width, height, depth = image.shape
        # Start offset per axis: half of the surplus, rounded to an integer.
        offsets = [
            int(round((extent - size) / 2.))
            for extent, size in zip((width, height, depth), target)
        ]
        window = tuple(
            slice(start, start + size) for start, size in zip(offsets, target)
        )
        return {'image': image[window], 'label': label[window]}
98
99
100
class RandomRotFlip(object):
    """Apply the same random quarter-turn rotation and flip to a sample.

    The number of 90-degree turns is drawn from {0, 1, 2, 3} and the flip
    axis from {0, 1}; both ``image`` and ``label`` receive the identical
    transform. This transform takes no constructor arguments.
    """

    def __call__(self, sample):
        """Return a dict with the rotated and flipped ``image`` and ``label``."""
        quarter_turns = np.random.randint(0, 4)
        flip_axis = np.random.randint(0, 2)

        def rotate_then_flip(volume):
            # ``copy`` materialises the result as a fresh array instead of a
            # view with reversed strides.
            rotated = np.rot90(volume, quarter_turns)
            return np.flip(rotated, axis=flip_axis).copy()

        image, label = sample['image'], sample['label']
        return {
            'image': rotate_then_flip(image),
            'label': rotate_then_flip(label),
        }
117
118
119
class ToTensor(object):
    """Turn the ndarrays of a sample into torch tensors."""

    def __call__(self, sample):
        """Return ``image`` as float32 with a leading axis and ``label`` as long."""
        volume = sample['image']
        shape = volume.shape
        # Prepend a singleton axis in front of the three volume axes.
        with_channel = volume.reshape(1, shape[0], shape[1], shape[2])
        image_tensor = torch.from_numpy(with_channel.astype(np.float32))
        label_tensor = torch.from_numpy(sample['label']).long()
        return {'image': image_tensor, 'label': label_tensor}
126
127
128
class TwoStreamBatchSampler(Sampler):
    """Yield batches that combine indices from two separate index pools.

    One pass over the sampler walks through a shuffled copy of the primary
    indices exactly once; the secondary indices are reshuffled and repeated
    for as long as primary groups remain. A trailing primary group that is
    smaller than the primary batch size is dropped.

    Args:
        primary_indices: Indices consumed once per pass.
        secondary_indices: Indices cycled through as often as required.
        batch_size: Total number of indices per batch.
        secondary_batch_size: How many of those come from the secondary pool;
            the remainder comes from the primary pool.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        # Each pool must be able to fill at least one non-empty share.
        assert 0 < self.primary_batch_size <= len(self.primary_indices)
        assert 0 < self.secondary_batch_size <= len(self.secondary_indices)

    def __iter__(self):
        """Return an iterator of tuples: primary share followed by secondary share."""
        primary_groups = grouper(
            iterate_once(self.primary_indices), self.primary_batch_size)
        secondary_groups = grouper(
            iterate_eternally(self.secondary_indices), self.secondary_batch_size)
        # zip stops as soon as the finite primary stream is exhausted.
        paired = zip(primary_groups, secondary_groups)
        return (first + second for first, second in paired)

    def __len__(self):
        """Return the number of full primary batches in one pass."""
        full_batches = len(self.primary_indices) // self.primary_batch_size
        return full_batches
156
157
def iterate_once(iterable):
    """Return one random permutation of ``iterable`` as a NumPy array."""
    shuffled = np.random.permutation(iterable)
    return shuffled
159
160
161
def iterate_eternally(indices):
    """Return an endless iterator over successive random permutations of ``indices``.

    A new permutation is drawn only when the previous one has been consumed.
    """
    endless_permutations = (
        np.random.permutation(indices) for _ in itertools.repeat(None)
    )
    return itertools.chain.from_iterable(endless_permutations)
166
167
168
def grouper(iterable, n):
    """Collect data into fixed-length tuples, dropping an incomplete tail.

    Example: grouper('ABCDEFG', 3) --> ABC DEF
    """
    # Every zip slot pulls from the same iterator, so consecutive items land
    # in the same tuple.
    shared = iter(iterable)
    return zip(*((shared,) * n))
173
174
175
if __name__ == '__main__':
    # Manual smoke test: load the training split from a hard-coded local path
    # and print every batch produced by the two-stream sampler.
    train_set = LAHeart('E:/data/LASet/data')
    print(len(train_set))
    # To inspect a single sample instead:
    #   data = train_set[0]
    #   print(data['image'].shape, data['label'].shape)
    labeled_idxs = list(range(25))
    unlabeled_idxs = list(range(25, 123))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)
    # Number the batches starting from 1.
    for i, x in enumerate(batch_sampler, start=1):
        print('%02d' % i, '\t', x)