a b/echonet/datasets/echo.py
1
"""EchoNet-Dynamic Dataset."""
2
3
import os
4
import collections
5
import pandas
6
7
import numpy as np
8
import skimage.draw
9
import torchvision
10
import echonet
11
12
13
class Echo(torchvision.datasets.VisionDataset):
    """EchoNet-Dynamic Dataset.

    Args:
        root (string): Root directory of dataset (defaults to `echonet.config.DATA_DIR`)
        split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''}
            (case-insensitive). ``clinical_test'' is also recognized; its videos are read
            from the ``ProcessedStrainStudyA4c'' directory under ``root''.
        target_type (string or list, optional): Type of target to use,
            ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'',
            ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'',
            or ``SmallTrace''
            Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``Filename'' (string): filename of video
                ``EF'' (float): ejection fraction
                ``EDV'' (float): end-diastolic volume
                ``ESV'' (float): end-systolic volume
                ``LargeIndex'' (int): index of large (diastolic) frame in video
                ``SmallIndex'' (int): index of small (systolic) frame in video
                ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame
                ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame
                ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation
                    value of 0 indicates pixel is outside left ventricle
                             1 indicates pixel is inside left ventricle
                ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation
                    value of 0 indicates pixel is outside left ventricle
                             1 indicates pixel is inside left ventricle
            Any other name is looked up as a column of ``FileList.csv'' (returned as float32;
            0 for the ``clinical_test'' and ``external_test'' splits).
            Defaults to ``EF''.
        mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel.
            Used for normalizing the video. Defaults to 0 (video is not shifted).
        std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel.
            Used for normalizing the video. Defaults to 1 (video is not scaled).
        length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned.
            Defaults to 16.
        period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken)
            Defaults to 2.
        max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively
            long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video.
            Defaults to 250.
        clips (int or ``all'', optional): Number of clips to sample. Main use is for test-time augmentation with random clips.
            ``all'' takes every possible clip start. Defaults to 1.
        pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation);
            a randomly positioned window of the original size is then cropped out. If ``None'', no padding occurs.
            Defaults to ``None''.
        noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added.
            Defaults to ``None''.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        external_test_location (string): Path to videos to use for external testing.
    """

    def __init__(self, root=None,
                 split="train", target_type="EF",
                 mean=0., std=1.,
                 length=16, period=2,
                 max_length=250,
                 clips=1,
                 pad=None,
                 noise=None,
                 target_transform=None,
                 external_test_location=None):
        if root is None:
            root = echonet.config.DATA_DIR

        super().__init__(root, target_transform=target_transform)

        self.split = split.upper()
        if not isinstance(target_type, list):
            target_type = [target_type]
        self.target_type = target_type
        self.mean = mean
        self.std = std
        self.length = length
        self.max_length = max_length
        self.period = period
        self.clips = clips
        self.pad = pad
        self.noise = noise
        self.target_transform = target_transform
        self.external_test_location = external_test_location

        self.fnames, self.outcome = [], []

        if self.split == "EXTERNAL_TEST":
            # No labels or traces exist for external videos; every file in the directory is used.
            self.fnames = sorted(os.listdir(self.external_test_location))
        else:
            self._load_labels()
            self._check_videos_present()
            self._load_traces()
            self._drop_videos_without_traces()

    def _load_labels(self):
        """Read FileList.csv, keep the requested split, and set ``header``, ``fnames`` and ``outcome``."""
        with open(os.path.join(self.root, "FileList.csv")) as f:
            data = pandas.read_csv(f)
        # Normalize split names so they compare against the upper-cased ``self.split``.
        # (The result must be assigned back; ``map`` does not modify the column in place.)
        data["Split"] = data["Split"].astype(str).str.upper()

        if self.split != "ALL":
            data = data[data["Split"] == self.split]

        self.header = data.columns.tolist()
        # Assume avi if no suffix. Names that already carry an extension are kept as-is, so
        # fnames[i] stays aligned with outcome[i] (one entry per CSV row).
        self.fnames = [fn + ".avi" if os.path.splitext(fn)[1] == "" else fn
                       for fn in data["FileName"].tolist()]
        self.outcome = data.values.tolist()

    def _check_videos_present(self):
        """Raise FileNotFoundError (after listing them) if any labelled video is missing from ``root/Videos``."""
        videos_dir = os.path.join(self.root, "Videos")
        missing = set(self.fnames) - set(os.listdir(videos_dir))
        if missing:
            print("{} videos could not be found in {}:".format(len(missing), videos_dir))
            for name in sorted(missing):
                print("\t", name)
            raise FileNotFoundError(os.path.join(videos_dir, sorted(missing)[0]))

    def _load_traces(self):
        """Parse VolumeTracings.csv into ``frames[file]`` (frame order of first appearance) and ``trace[file][frame]`` arrays."""
        self.frames = collections.defaultdict(list)
        self.trace = collections.defaultdict(_defaultdict_of_lists)

        with open(os.path.join(self.root, "VolumeTracings.csv")) as f:
            header = f.readline().strip().split(",")
            assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]

            for line in f:
                filename, x1, y1, x2, y2, frame = line.strip().split(',')
                frame = int(frame)
                if frame not in self.trace[filename]:
                    self.frames[filename].append(frame)
                self.trace[filename][frame].append((float(x1), float(y1), float(x2), float(y2)))

        # Each trace becomes an (n_segments, 4) array of (x1, y1, x2, y2) rows.
        for filename, frames in self.frames.items():
            for frame in frames:
                self.trace[filename][frame] = np.array(self.trace[filename][frame])

    def _drop_videos_without_traces(self):
        """Remove videos with fewer than two traced frames (a small number of videos are missing traces)."""
        keep = [len(self.frames.get(f, ())) >= 2 for f in self.fnames]
        self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
        self.outcome = [o for (o, k) in zip(self.outcome, keep) if k]

    def _video_path(self, index):
        """Return the path of video ``index`` for the current split."""
        fname = self.fnames[index]
        if self.split == "EXTERNAL_TEST":
            return os.path.join(self.external_test_location, fname)
        if self.split == "CLINICAL_TEST":
            return os.path.join(self.root, "ProcessedStrainStudyA4c", fname)
        return os.path.join(self.root, "Videos", fname)

    def _add_noise(self, video):
        """Zero a ``self.noise`` fraction of (frame, row, column) positions in place, across all channels."""
        _, frames, height, width = video.shape
        n = frames * height * width
        ind = np.random.choice(n, round(self.noise * n), replace=False)
        # Decode flat indices into (frame, row, column) coordinates.
        f = ind % frames
        ind //= frames
        i = ind % height
        j = ind // height
        video[:, f, i, j] = 0

    @staticmethod
    def _trace_mask(trace, height, width):
        """Rasterize an (n, 4) array of (x1, y1, x2, y2) segments into a float32 0/1 mask of shape (height, width)."""
        x1, y1, x2, y2 = trace[:, 0], trace[:, 1], trace[:, 2], trace[:, 3]
        # The first row is excluded from the outline (presumably the long-axis segment — TODO confirm);
        # the remaining x1/y1 endpoints down one side and x2/y2 endpoints back up the other form the polygon.
        x = np.concatenate((x1[1:], np.flip(x2[1:])))
        y = np.concatenate((y1[1:], np.flip(y2[1:])))

        rows, cols = skimage.draw.polygon(np.rint(y).astype(int), np.rint(x).astype(int), (height, width))
        mask = np.zeros((height, width), np.float32)
        mask[rows, cols] = 1
        return mask

    def _gather_target(self, index, video):
        """Build the target(s) for ``index``; ``video`` is the full normalized (and possibly padded) video."""
        key = self.fnames[index]
        target = []
        for t in self.target_type:
            if t == "Filename":
                target.append(key)
            elif t == "LargeIndex":
                # Assumes VolumeTracings.csv lists the large (diastolic) frame last — TODO confirm per dataset.
                target.append(int(self.frames[key][-1]))
            elif t == "SmallIndex":
                # Smallest (systolic) frame is assumed to be listed first.
                target.append(int(self.frames[key][0]))
            elif t == "LargeFrame":
                target.append(video[:, self.frames[key][-1], :, :])
            elif t == "SmallFrame":
                target.append(video[:, self.frames[key][0], :, :])
            elif t in ("LargeTrace", "SmallTrace"):
                frame = self.frames[key][-1] if t == "LargeTrace" else self.frames[key][0]
                target.append(self._trace_mask(self.trace[key][frame], video.shape[2], video.shape[3]))
            elif self.split in ("CLINICAL_TEST", "EXTERNAL_TEST"):
                # No labels are available for these splits.
                target.append(np.float32(0))
            else:
                target.append(np.float32(self.outcome[index][self.header.index(t)]))

        if target:
            target = tuple(target) if len(target) > 1 else target[0]
            if self.target_transform is not None:
                target = self.target_transform(target)
        return target

    def _random_pad(self, video):
        """Zero-pad the last two axes by ``self.pad`` and crop a random window of the original size."""
        p = self.pad
        *lead, h, w = video.shape  # works for both single-clip (4-D) and stacked-clip (5-D) videos
        # Zeros correspond to ``mean`` after normalization.
        temp = np.zeros((*lead, h + 2 * p, w + 2 * p), dtype=video.dtype)
        temp[..., p:p + h, p:p + w] = video
        # randint's upper bound is exclusive: +1 makes every offset 0..2p reachable (and allows pad=0).
        i, j = np.random.randint(0, 2 * p + 1, 2)
        return temp[..., i:(i + h), j:(j + w)]

    def __getitem__(self, index):
        """Return ``(video, target)`` for item ``index``.

        ``video`` is float32 shaped (channels, length, height, width) when ``clips == 1``,
        otherwise the clips are stacked along a new leading axis.
        """
        video = echonet.utils.loadvideo(self._video_path(index)).astype(np.float32)

        # Add simulated noise (black out random pixels).
        # 0 represents black at this point (video has not been normalized yet).
        if self.noise is not None:
            self._add_noise(video)

        # Apply normalization
        if isinstance(self.mean, (float, int)):
            video -= self.mean
        else:
            video -= self.mean.reshape(3, 1, 1, 1)

        if isinstance(self.std, (float, int)):
            video /= self.std
        else:
            video /= self.std.reshape(3, 1, 1, 1)

        # Set number of frames
        c, f, h, w = video.shape
        if self.length is None:
            # Take as many frames as possible
            length = f // self.period
        else:
            length = self.length

        if self.max_length is not None:
            # Shorten videos to max_length
            length = min(length, self.max_length)

        if f < length * self.period:
            # Pad video with zero frames if too short; 0 corresponds to ``mean`` after normalization.
            video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
            f = video.shape[1]

        if self.clips == "all":
            # Take all possible clips of desired length
            start = np.arange(f - (length - 1) * self.period)
        else:
            # Take random clips from video
            start = np.random.choice(f - (length - 1) * self.period, self.clips)

        # Targets are taken from the full video, before clip selection.
        target = self._gather_target(index, video)

        # Select clips from video
        video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
        video = video[0] if self.clips == 1 else np.stack(video)

        if self.pad is not None:
            # Augmentation: random shift via pad-and-crop
            video = self._random_pad(video)

        return video, target

    def __len__(self):
        return len(self.fnames)

    def extra_repr(self) -> str:
        """Additional information to add at end of __repr__."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)
273
274
275
def _defaultdict_of_lists():
276
    """Returns a defaultdict of lists.
277
278
    This is used to avoid issues with Windows (if this function is anonymous,
279
    the Echo dataset cannot be used in a dataloader).
280
    """
281
282
    return collections.defaultdict(list)