echonet/utils/segmentation.py
"""Functions for training and running segmentation."""

import math
import os
import time

import click
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
import skimage.draw
import torch
import torchvision
import tqdm

import echonet


@click.command("segmentation")
@click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
@click.option("--output", type=click.Path(file_okay=False), default=None)
@click.option("--model_name", type=click.Choice(
    sorted(name for name in torchvision.models.segmentation.__dict__
           if name.islower() and not name.startswith("__") and callable(torchvision.models.segmentation.__dict__[name]))),
              default="deeplabv3_resnet50")
@click.option("--pretrained/--random", default=False)
@click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
@click.option("--run_test/--skip_test", default=False)
@click.option("--save_video/--skip_video", default=False)
@click.option("--num_epochs", type=int, default=50)
@click.option("--lr", type=float, default=1e-5)
@click.option("--weight_decay", type=float, default=0)
@click.option("--lr_step_period", type=int, default=None)
@click.option("--num_train_patients", type=int, default=None)
@click.option("--num_workers", type=int, default=4)
@click.option("--batch_size", type=int, default=20)
@click.option("--device", type=str, default=None)
@click.option("--seed", type=int, default=0)
def run(
    data_dir=None,
    output=None,

    model_name="deeplabv3_resnet50",
    pretrained=False,
    weights=None,

    run_test=False,
    save_video=False,
    num_epochs=50,
    lr=1e-5,
    weight_decay=0,
    lr_step_period=None,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
    """Trains/tests segmentation model.

    Args:
        data_dir (str, optional): Directory containing dataset. Defaults to
            `echonet.config.DATA_DIR`.
        output (str, optional): Directory to place outputs. Defaults to
            output/segmentation/<model_name>_<pretrained/random>/.
        model_name (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50'',
            ``deeplabv3_resnet101'', ``fcn_resnet50'', or ``fcn_resnet101''
            (options are torchvision.models.segmentation.<model_name>).
            Defaults to ``deeplabv3_resnet50''.
        pretrained (bool, optional): Whether to use pretrained weights for the model.
            Defaults to False.
        weights (str, optional): Path to checkpoint containing weights to
            initialize model. Defaults to None.
        run_test (bool, optional): Whether or not to run on the test split.
            Defaults to False.
        save_video (bool, optional): Whether to save videos with segmentations.
            Defaults to False.
        num_epochs (int, optional): Number of epochs during training.
            Defaults to 50.
        lr (float, optional): Learning rate for SGD.
            Defaults to 1e-5.
        weight_decay (float, optional): Weight decay for SGD.
            Defaults to 0.
        lr_step_period (int or None, optional): Period of learning rate decay
            (learning rate is decayed by a multiplicative factor of 0.1).
            Defaults to math.inf (never decay learning rate).
        num_train_patients (int or None, optional): Number of training patients
            used (for ablations). Defaults to all patients.
        num_workers (int, optional): Number of subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.
        batch_size (int, optional): Number of samples to load per batch.
            Defaults to 20.
        device (str or None, optional): Name of device to run on. Options from
            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
            Defaults to ``cuda'' if available, and ``cpu'' otherwise.
        seed (int, optional): Seed for random number generator. Defaults to 0.
    """

    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up model
    model = torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, aux_loss=False)

    model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)  # change number of outputs to 1
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)
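    # With the single-channel head above, model(x)["out"] has shape
    # (batch, 1, height, width); a pixel is treated as inside the segmentation
    # whenever its logit is positive (see the thresholds in run_epoch and the
    # video rendering below).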

    if weights is not None:
        checkpoint = torch.load(weights)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        lr_step_period = math.inf
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
    kwargs = {"target_type": tasks,
              "mean": mean,
              "std": std
              }
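    # ``mean`` and ``std`` are per-channel statistics over the training split and
    # are passed to each Echo dataset below, which is expected to normalize frames
    # as (frame - mean) / std; the video-saving code later reverses this before
    # writing pixels.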

    # Set up datasets and dataloaders
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        bestLoss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            bestLoss = checkpoint["best_loss"]
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                for i in range(torch.cuda.device_count()):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))

                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, phase == "train", optim, device)
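                # Dice is recovered from the accumulated pixel counts using
                # |A| + |B| = |A union B| + |A intersect B|, so
                # dice = 2 * |A intersect B| / (|A| + |B|) = 2 * inter / (union + inter).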
                overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
                large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
                small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
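                # Append one row per (epoch, phase) to log.csv: epoch, phase, loss,
                # overall/large/small Dice, wall-clock seconds, number of examples,
                # peak CUDA memory allocated and reserved, and batch size.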
                f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
                                                                    phase,
                                                                    loss,
                                                                    overall_dice,
                                                                    large_dice,
                                                                    small_dice,
                                                                    time.time() - start_time,
                                                                    large_inter.size,
                                                                    sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
                                                                    sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
                                                                    batch_size))
                f.flush()
            scheduler.step()

            # Save checkpoint
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_loss': bestLoss,
                'loss': loss,
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < bestLoss:
                torch.save(save, os.path.join(output, "best.pt"))
                bestLoss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))

        if run_test:
            # Run on validation and test
            for split in ["val", "test"]:
                dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs)
                dataloader = torch.utils.data.DataLoader(dataset,
                                                         batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, False, None, device)

                overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
                large_dice = 2 * large_inter / (large_union + large_inter)
                small_dice = 2 * small_inter / (small_union + small_inter)
                with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
                    g.write("Filename, Overall, Large, Small\n")
                    for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
                        g.write("{},{},{},{}\n".format(filename, overall, large, small))

                f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (large): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (small): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
                f.flush()

        # Saving videos with segmentations
        dataset = echonet.datasets.Echo(root=data_dir, split="test",
                                        target_type=["Filename", "LargeIndex", "SmallIndex"],  # Need filename for saving, and human-selected frames to annotate
                                        mean=mean, std=std,  # Normalization
                                        length=None, max_length=None, period=1  # Take all frames
                                        )
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)

        # Save videos with segmentation
        if save_video and not all(os.path.isfile(os.path.join(output, "videos", f)) for f in dataloader.dataset.fnames):
            # Only run if missing videos

            model.eval()

            os.makedirs(os.path.join(output, "videos"), exist_ok=True)
            os.makedirs(os.path.join(output, "size"), exist_ok=True)
            echonet.utils.latexify()

            with torch.no_grad():
                with open(os.path.join(output, "size.csv"), "w") as g:
                    g.write("Filename,Frame,Size,HumanLarge,HumanSmall,ComputerSmall\n")
                    for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
                        # Run segmentation model on blocks of frames one-by-one
                        # The whole concatenated video may be too long to run together
                        y = np.concatenate([model(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])

                        start = 0
                        x = x.numpy()
                        for (i, (filename, offset)) in enumerate(zip(filenames, length)):
                            # Extract one video and segmentation predictions
                            video = x[start:(start + offset), ...]
                            logit = y[start:(start + offset), 0, :, :]

                            # Un-normalize video
                            video *= std.reshape(1, 3, 1, 1)
                            video += mean.reshape(1, 3, 1, 1)

                            # Get frames, channels, height, and width
                            f, c, h, w = video.shape  # pylint: disable=W0612
                            assert c == 3

                            # Put two copies of the video side by side
                            video = np.concatenate((video, video), 3)

                            # If a pixel is in the segmentation, saturate blue channel
                            # Leave alone otherwise
                            video[:, 0, :, w:] = np.maximum(255. * (logit > 0), video[:, 0, :, w:])  # pylint: disable=E1111

                            # Add blank canvas under pair of videos
                            video = np.concatenate((video, np.zeros_like(video)), 2)
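                            # Resulting canvas: top-left is the raw clip, top-right the
                            # copy with the predicted segmentation saturating channel 0,
                            # and the bottom half is blank space where the size curve is
                            # drawn (rows 115-224) by the loop below.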

                            # Compute size of segmentation per frame
                            size = (logit > 0).sum((1, 2))

                            # Identify systole frames with peak detection
                            trim_min = sorted(size)[round(len(size) ** 0.05)]
                            trim_max = sorted(size)[round(len(size) ** 0.95)]
                            trim_range = trim_max - trim_min
                            systole = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])
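                            # find_peaks on the negated size series returns local minima of
                            # the segmentation area (candidate systole frames); ``distance=20``
                            # keeps detected beats at least 20 frames apart and ``prominence``
                            # discards shallow dips relative to the trimmed size range.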

                            # Write sizes and frames to file
                            for (frame, s) in enumerate(size):
                                g.write("{},{},{},{},{},{}\n".format(filename, frame, s, 1 if frame == large_index[i] else 0, 1 if frame == small_index[i] else 0, 1 if frame in systole else 0))

                            # Plot sizes
                            fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))
                            plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)
                            ylim = plt.ylim()
                            for s in systole:
                                plt.plot(np.array([s, s]) / 50, ylim, linewidth=1)
                            plt.ylim(ylim)
                            plt.title(os.path.splitext(filename)[0])
                            plt.xlabel("Seconds")
                            plt.ylabel("Size (pixels)")
                            plt.tight_layout()
                            plt.savefig(os.path.join(output, "size", os.path.splitext(filename)[0] + ".pdf"))
                            plt.close(fig)

                            # Normalize size to [0, 1]
                            size -= size.min()
                            size = size / size.max()
                            size = 1 - size

                            # Iterate the frames in this video
                            for (f, s) in enumerate(size):

                                # On all frames, mark a pixel for the size of the frame
                                video[:, :, int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))] = 255.

                                if f in systole:
                                    # If frame is computer-selected systole, mark with a line
                                    video[:, :, 115:224, int(round(f / len(size) * 200 + 10))] = 255.

                                def dash(start, stop, on=10, off=10):
                                    buf = []
                                    x = start
                                    while x < stop:
                                        buf.extend(range(x, x + on))
                                        x += on
                                        x += off
                                    buf = np.array(buf)
                                    buf = buf[buf < stop]
                                    return buf
                                d = dash(115, 224)
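                                # ``dash`` builds row indices for a dashed vertical line
                                # (``on`` pixels drawn, ``off`` pixels skipped) spanning
                                # rows 115-224 of the lower canvas.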

                                if f == large_index[i]:
                                    # If frame is human-selected diastole, mark with green dashed line on all frames
                                    video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 225, 0]).reshape((1, 3, 1))
                                if f == small_index[i]:
                                    # If frame is human-selected systole, mark with red dashed line on all frames
                                    video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 0, 225]).reshape((1, 3, 1))

                                # Get pixels for a circle centered on the pixel
                                r, c = skimage.draw.disk((int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))), 4.1)

                                # On the frame that's being shown, put a circle over the pixel
                                video[f, :, r, c] = 255.

                            # Rearrange dimensions and save
                            video = video.transpose(1, 0, 2, 3)
                            video = video.astype(np.uint8)
                            echonet.utils.savevideo(os.path.join(output, "videos", filename), video, 50)

                            # Move to next video
                            start += offset


def run_epoch(model, dataloader, train, optim, device):
    """Run one epoch of training/evaluation for segmentation.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to run on.
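
    Returns:
        Tuple of (loss averaged per pixel (frames are assumed to be 112x112),
        per-video intersection counts for large traces, per-video union counts
        for large traces, per-video intersection counts for small traces,
        per-video union counts for small traces), with the counts as numpy arrays.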
    """

    total = 0.
    n = 0

    pos = 0
    neg = 0
    pos_pix = 0
    neg_pix = 0

    model.train(train)

    large_inter = 0
    large_union = 0
    small_inter = 0
    small_union = 0
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
                # Count number of pixels in/out of human segmentation
                pos += (large_trace == 1).sum().item()
                pos += (small_trace == 1).sum().item()
                neg += (large_trace == 0).sum().item()
                neg += (small_trace == 0).sum().item()

                # Count number of pixels in/out of human segmentation at each pixel position
                pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
                pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
                neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
                neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()

                # Run prediction for diastolic frames and compute loss
                large_frame = large_frame.to(device)
                large_trace = large_trace.to(device)
                y_large = model(large_frame)["out"]
                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Run prediction for systolic frames and compute loss
                small_frame = small_frame.to(device)
                small_trace = small_trace.to(device)
                y_small = model(small_frame)["out"]
                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Take gradient step if training
                loss = (loss_large + loss_small) / 2
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Accumulate losses and compute baselines
                total += loss.item()
                n += large_trace.size(0)
                p = pos / (pos + neg)
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)

                # Show info on progress bar
                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter)))
                pbar.update()

    large_inter_list = np.array(large_inter_list)
    large_union_list = np.array(large_union_list)
    small_inter_list = np.array(small_inter_list)
    small_union_list = np.array(small_union_list)

    return (total / n / 112 / 112,
            large_inter_list,
            large_union_list,
            small_inter_list,
            small_union_list,
            )


def _video_collate_fn(x):
    """Collate function for Pytorch dataloader to merge multiple videos.

    This function should be used in a dataloader for a dataset that returns
    a video as the first element, along with some (non-zero) tuple of
    targets. Then, the input x is a list of tuples:
      - x[i][0] is the i-th video in the batch
      - x[i][1] are the targets for the i-th video

    This function returns a 3-tuple:
      - The first element is the videos concatenated along the frames
        dimension. This is done so that videos of different lengths can be
        processed together (tensors cannot be "jagged", so we cannot have
        a dimension for video, and another for frames).
      - The second element contains the targets, regrouped so that each entry
        holds one target type across the whole batch (see the ``zip`` below).
      - The third element is a list of the lengths of the videos in frames.
    """
    video, target = zip(*x)  # Extract the videos and targets

    # ``video'' is a tuple of length ``batch_size''
    # Each element has shape (channels=3, frames, height, width)
    # height and width are expected to be the same across videos, but
    # frames can be different.

    # ``target'' is also a tuple of length ``batch_size''
    # Each element is a tuple of the targets for the item.

    i = list(map(lambda t: t.shape[1], video))  # Extract lengths of videos in frames

    # This concatenates the videos along the frames dimension (basically
    # playing the videos one after another). The frames dimension is then
    # moved to be first.
    # Resulting shape is (total frames, channels=3, height, width)
    video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))

    # Swap dimensions (approximately a transpose)
    # Before: target[i][j] is the j-th target of element i
    # After: target[i][j] is the i-th target of element j
    target = zip(*target)
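    # Note: ``zip`` returns a lazy iterator that can be consumed only once; the
    # caller above unpacks it as (filenames, large_index, small_index) when
    # iterating the dataloader.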

    return video, target, i