Diff of /code/utils_model.py [000000] .. [594161]

Switch to unified view

a b/code/utils_model.py
1
"""
2
DeepSlide
3
Using ResNet to train and test.
4
5
Authors: Jason Wei, Behnaz Abdollahi, Saeed Hassanpour, Naofumi Tomita
6
"""
7
8
import operator
9
import random
10
import time
11
from pathlib import Path
12
from typing import (Dict, IO, List, Tuple)
13
14
import numpy as np
15
import pandas as pd
16
import torch
17
import torch.nn as nn
18
import torch.optim as optim
19
import torchvision
20
from PIL import Image
21
from torch.optim import lr_scheduler
22
from torchvision import (datasets, transforms)
23
24
from utils import (get_image_paths, get_subfolder_paths)
25
26
###########################################
27
#             MISC FUNCTIONS              #
28
###########################################
29
30
31
def calculate_confusion_matrix(all_labels: np.ndarray,
32
                               all_predicts: np.ndarray, classes: List[str],
33
                               num_classes: int) -> None:
34
    """
35
    Prints the confusion matrix from the given data.
36
37
    Args:
38
        all_labels: The ground truth labels.
39
        all_predicts: The predicted labels.
40
        classes: Names of the classes in the dataset.
41
        num_classes: Number of classes in the dataset.
42
    """
43
    remap_classes = {x: classes[x] for x in range(num_classes)}
44
45
    # Set print options.
46
    # Sources:
47
    #   1. https://stackoverflow.com/questions/42735541/customized-float-formatting-in-a-pandas-dataframe
48
    #   2. https://stackoverflow.com/questions/11707586/how-do-i-expand-the-output-display-to-see-more-columns-of-a-pandas-dataframe
49
    #   3. https://pandas.pydata.org/pandas-docs/stable/user_guide/style.html
50
    pd.options.display.float_format = "{:.2f}".format
51
    pd.options.display.width = 0
52
53
    actual = pd.Series(pd.Categorical(
54
        pd.Series(all_labels).replace(remap_classes), categories=classes),
55
                       name="Actual")
56
57
    predicted = pd.Series(pd.Categorical(
58
        pd.Series(all_predicts).replace(remap_classes), categories=classes),
59
                          name="Predicted")
60
61
    cm = pd.crosstab(index=actual, columns=predicted, normalize="index", dropna=False)
62
63
    
64
    # cm.style.hide_index() 
65
    # Pandas hide_index method became deprecated since the version 1.4.0,
66
    # should be replaced by:
67
    cm.style.hide()
68
69
    print(cm)
70
71
72
class Random90Rotation:
73
    def __init__(self, degrees: Tuple[int] = None) -> None:
74
        """
75
        Randomly rotate the image for training. Credits to Naofumi Tomita.
76
77
        Args:
78
            degrees: Degrees available for rotation.
79
        """
80
        self.degrees = (0, 90, 180, 270) if (degrees is None) else degrees
81
82
    def __call__(self, im: Image) -> Image:
83
        """
84
        Produces a randomly rotated image every time the instance is called.
85
86
        Args:
87
            im: The image to rotate.
88
89
        Returns:    
90
            Randomly rotated image.
91
        """
92
        return im.rotate(angle=random.sample(population=self.degrees, k=1)[0])
93
94
95
def create_model(num_layers: int, num_classes: int,
96
                 pretrain: bool) -> torchvision.models.resnet.ResNet:
97
    """
98
    Instantiate the ResNet model.
99
100
    Args:
101
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
102
        num_classes: Number of classes in the dataset.
103
        pretrain: Use pretrained ResNet weights.
104
105
    Returns:
106
        The instantiated ResNet model with the requested parameters.
107
    """
108
    assert num_layers in (
109
        18, 34, 50, 101, 152
110
    ), f"Invalid number of ResNet Layers. Must be one of [18, 34, 50, 101, 152] and not {num_layers}"
111
    model_constructor = getattr(torchvision.models, f"resnet{num_layers}")
112
    model = model_constructor(num_classes=num_classes)
113
114
    if pretrain:
115
        pretrained = model_constructor(pretrained=True).state_dict()
116
        if num_classes != pretrained["fc.weight"].size(0):
117
            del pretrained["fc.weight"], pretrained["fc.bias"]
118
        model.load_state_dict(state_dict=pretrained, strict=False)
119
    return model
120
121
122
def get_data_transforms(color_jitter_brightness: float,
123
                        color_jitter_contrast: float,
124
                        color_jitter_saturation: float,
125
                        color_jitter_hue: float, path_mean: List[float],
126
                        path_std: List[float]
127
                        ) -> Dict[str, torchvision.transforms.Compose]:
128
    """
129
    Sets up the dataset transforms for training and validation.
130
131
    Args:
132
        color_jitter_brightness: Random brightness jitter to use in data augmentation for ColorJitter() transform.
133
        color_jitter_contrast: Random contrast jitter to use in data augmentation for ColorJitter() transform.
134
        color_jitter_saturation: Random saturation jitter to use in data augmentation for ColorJitter() transform.
135
        color_jitter_hue: Random hue jitter to use in data augmentation for ColorJitter() transform.
136
        path_mean: Means of the WSIs for each dimension.
137
        path_std: Standard deviations of the WSIs for each dimension.
138
139
    Returns:
140
        A dictionary mapping training and validation strings to data transforms.
141
    """
142
    return {
143
        "train":
144
        transforms.Compose(transforms=[
145
            transforms.ColorJitter(brightness=color_jitter_brightness,
146
                                   contrast=color_jitter_contrast,
147
                                   saturation=color_jitter_saturation,
148
                                   hue=color_jitter_hue),
149
            transforms.RandomHorizontalFlip(),
150
            transforms.RandomVerticalFlip(),
151
            Random90Rotation(),
152
            transforms.ToTensor(),
153
            transforms.Normalize(mean=path_mean, std=path_std)
154
        ]),
155
        "val":
156
        transforms.Compose(transforms=[
157
            transforms.ToTensor(),
158
            transforms.Normalize(mean=path_mean, std=path_std)
159
        ])
160
    }
161
162
163
def print_params(train_folder: Path, num_epochs: int, num_layers: int,
164
                 learning_rate: float, batch_size: int, weight_decay: float,
165
                 learning_rate_decay: float, resume_checkpoint: bool,
166
                 resume_checkpoint_path: Path, save_interval: int,
167
                 checkpoints_folder: Path, pretrain: bool,
168
                 log_csv: Path) -> None:
169
    """
170
    Print the configuration of the model.
171
172
    Args:
173
        train_folder: Location of the automatically built training input folder.
174
        num_epochs: Number of epochs for training.
175
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
176
        learning_rate: Learning rate to use for gradient descent.
177
        batch_size: Mini-batch size to use for training.
178
        weight_decay: Weight decay (L2 penalty) to use in optimizer.
179
        learning_rate_decay: Learning rate decay amount per epoch.
180
        resume_checkpoint: Resume model from checkpoint file.
181
        resume_checkpoint_path: Path to the checkpoint file for resuming training.
182
        save_interval: Number of epochs between saving checkpoints.
183
        checkpoints_folder: Directory to save model checkpoints to.
184
        pretrain: Use pretrained ResNet weights.
185
        log_csv: Name of the CSV file containing the logs.
186
    """
187
    print(f"train_folder: {train_folder}\n"
188
          f"num_epochs: {num_epochs}\n"
189
          f"num_layers: {num_layers}\n"
190
          f"learning_rate: {learning_rate}\n"
191
          f"batch_size: {batch_size}\n"
192
          f"weight_decay: {weight_decay}\n"
193
          f"learning_rate_decay: {learning_rate_decay}\n"
194
          f"resume_checkpoint: {resume_checkpoint}\n"
195
          f"resume_checkpoint_path (only if resume_checkpoint is true): "
196
          f"{resume_checkpoint_path}\n"
197
          f"save_interval: {save_interval}\n"
198
          f"output in checkpoints_folder: {checkpoints_folder}\n"
199
          f"pretrain: {pretrain}\n"
200
          f"log_csv: {log_csv}\n\n")
201
202
203
###########################################
204
#          MAIN TRAIN FUNCTION            #
205
###########################################
206
207
208
def train_helper(model: torchvision.models.resnet.ResNet,
209
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
210
                 dataset_sizes: Dict[str, int],
211
                 criterion: torch.nn.modules.loss, optimizer: torch.optim,
212
                 scheduler: torch.optim.lr_scheduler, num_epochs: int,
213
                 writer: IO, device: torch.device, start_epoch: int,
214
                 batch_size: int, save_interval: int, checkpoints_folder: Path,
215
                 num_layers: int, classes: List[str],
216
                 num_classes: int) -> None:
217
    """
218
    Function for training ResNet.
219
220
    Args:
221
        model: ResNet model for training.
222
        dataloaders: Dataloaders for IO pipeline.
223
        dataset_sizes: Sizes of the training and validation dataset.
224
        criterion: Metric used for calculating loss.
225
        optimizer: Optimizer to use for gradient descent.
226
        scheduler: Scheduler to use for learning rate decay.
227
        start_epoch: Starting epoch for training.
228
        writer: Writer to write logging information.
229
        device: Device to use for running model.
230
        num_epochs: Total number of epochs to train for.
231
        batch_size: Mini-batch size to use for training.
232
        save_interval: Number of epochs between saving checkpoints.
233
        checkpoints_folder: Directory to save model checkpoints to.
234
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
235
        classes: Names of the classes in the dataset.
236
        num_classes: Number of classes in the dataset.
237
    """
238
    since = time.time()
239
240
    # Initialize all the tensors to be used in training and validation.
241
    # Do this outside the loop since it will be written over entirely at each
242
    # epoch and doesn't need to be reallocated each time.
243
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
244
                                   dtype=torch.long).cpu()
245
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
246
                                     dtype=torch.long).cpu()
247
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
248
                                 dtype=torch.long).cpu()
249
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
250
                                   dtype=torch.long).cpu()
251
252
    # Train for specified number of epochs.
253
    for epoch in range(start_epoch, num_epochs):
254
255
        # Training phase.
256
        model.train(mode=True)
257
258
        train_running_loss = 0.0
259
        train_running_corrects = 0
260
261
        # Train over all training data.
262
        for idx, (inputs, labels) in enumerate(dataloaders["train"]):
263
            train_inputs = inputs.to(device=device)
264
            train_labels = labels.to(device=device)
265
            optimizer.zero_grad()
266
267
            # Forward and backpropagation.
268
            with torch.set_grad_enabled(mode=True):
269
                train_outputs = model(train_inputs)
270
                __, train_preds = torch.max(train_outputs, dim=1)
271
                train_loss = criterion(input=train_outputs,
272
                                       target=train_labels)
273
                train_loss.backward()
274
                optimizer.step()
275
276
            # Update training diagnostics.
277
            train_running_loss += train_loss.item() * train_inputs.size(0)
278
            train_running_corrects += torch.sum(
279
                train_preds == train_labels.data, dtype=torch.double)
280
281
            start = idx * batch_size
282
            end = start + batch_size
283
284
            train_all_labels[start:end] = train_labels.detach().cpu()
285
            train_all_predicts[start:end] = train_preds.detach().cpu()
286
287
        calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
288
                                   all_predicts=train_all_predicts.numpy(),
289
                                   classes=classes,
290
                                   num_classes=num_classes)
291
292
        # Store training diagnostics.
293
        train_loss = train_running_loss / dataset_sizes["train"]
294
        train_acc = train_running_corrects / dataset_sizes["train"]
295
296
        if torch.cuda.is_available():
297
            torch.cuda.empty_cache()
298
299
        # Validation phase.
300
        model.train(mode=False)
301
302
        val_running_loss = 0.0
303
        val_running_corrects = 0
304
305
        # Feed forward over all the validation data.
306
        for idx, (val_inputs, val_labels) in enumerate(dataloaders["val"]):
307
            val_inputs = val_inputs.to(device=device)
308
            val_labels = val_labels.to(device=device)
309
310
            # Feed forward.
311
            with torch.set_grad_enabled(mode=False):
312
                val_outputs = model(val_inputs)
313
                _, val_preds = torch.max(val_outputs, dim=1)
314
                val_loss = criterion(input=val_outputs, target=val_labels)
315
316
            # Update validation diagnostics.
317
            val_running_loss += val_loss.item() * val_inputs.size(0)
318
            val_running_corrects += torch.sum(val_preds == val_labels.data,
319
                                              dtype=torch.double)
320
321
            start = idx * batch_size
322
            end = start + batch_size
323
324
            val_all_labels[start:end] = val_labels.detach().cpu()
325
            val_all_predicts[start:end] = val_preds.detach().cpu()
326
327
        calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
328
                                   all_predicts=val_all_predicts.numpy(),
329
                                   classes=classes,
330
                                   num_classes=num_classes)
331
332
        # Store validation diagnostics.
333
        val_loss = val_running_loss / dataset_sizes["val"]
334
        val_acc = val_running_corrects / dataset_sizes["val"]
335
336
        if torch.cuda.is_available():
337
            torch.cuda.empty_cache()
338
339
        scheduler.step()
340
341
        current_lr = None
342
        for group in optimizer.param_groups:
343
            current_lr = group["lr"]
344
345
        # Remaining things related to training.
346
        if epoch % save_interval == 0:
347
            epoch_output_path = checkpoints_folder.joinpath(
348
                f"resnet{num_layers}_e{epoch}_va{val_acc:.5f}.pt")
349
350
            # Confirm the output directory exists.
351
            epoch_output_path.parent.mkdir(parents=True, exist_ok=True)
352
353
            # Save the model as a state dictionary.
354
            torch.save(obj={
355
                "model_state_dict": model.state_dict(),
356
                "optimizer_state_dict": optimizer.state_dict(),
357
                "scheduler_state_dict": scheduler.state_dict(),
358
                "epoch": epoch + 1
359
            },
360
                       f=str(epoch_output_path))
361
362
        writer.write(f"{epoch},{train_loss:.4f},"
363
                     f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")
364
365
        # Print the diagnostics for each epoch.
366
        print(f"Epoch {epoch} with lr "
367
              f"{current_lr:.15f}: "
368
              f"t_loss: {train_loss:.4f} "
369
              f"t_acc: {train_acc:.4f} "
370
              f"v_loss: {val_loss:.4f} "
371
              f"v_acc: {val_acc:.4f}\n")
372
373
    # Print training information at the end.
374
    print(f"\ntraining complete in "
375
          f"{(time.time() - since) // 60:.2f} minutes")
376
377
378
def train_resnet(
379
        train_folder: Path, batch_size: int, num_workers: int,
380
        device: torch.device, classes: List[str], learning_rate: float,
381
        weight_decay: float, learning_rate_decay: float,
382
        resume_checkpoint: bool, resume_checkpoint_path: Path, log_csv: Path,
383
        color_jitter_brightness: float, color_jitter_contrast: float,
384
        color_jitter_hue: float, color_jitter_saturation: float,
385
        path_mean: List[float], path_std: List[float], num_classes: int,
386
        num_layers: int, pretrain: bool, checkpoints_folder: Path,
387
        num_epochs: int, save_interval: int) -> None:
388
    """
389
    Main function for training ResNet.
390
391
    Args:
392
        train_folder: Location of the automatically built training input folder.
393
        batch_size: Mini-batch size to use for training.
394
        num_workers: Number of workers to use for IO.
395
        device: Device to use for running model.
396
        classes: Names of the classes in the dataset.
397
        learning_rate: Learning rate to use for gradient descent.
398
        weight_decay: Weight decay (L2 penalty) to use in optimizer.
399
        learning_rate_decay: Learning rate decay amount per epoch.
400
        resume_checkpoint: Resume model from checkpoint file.
401
        resume_checkpoint_path: Path to the checkpoint file for resuming training.
402
        log_csv: Name of the CSV file containing the logs.
403
        color_jitter_brightness: Random brightness jitter to use in data augmentation for ColorJitter() transform.
404
        color_jitter_contrast: Random contrast jitter to use in data augmentation for ColorJitter() transform.
405
        color_jitter_hue: Random hue jitter to use in data augmentation for ColorJitter() transform.
406
        color_jitter_saturation: Random saturation jitter to use in data augmentation for ColorJitter() transform.
407
        path_mean: Means of the WSIs for each dimension.
408
        path_std: Standard deviations of the WSIs for each dimension.
409
        num_classes: Number of classes in the dataset.
410
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
411
        pretrain: Use pretrained ResNet weights.
412
        checkpoints_folder: Directory to save model checkpoints to.
413
        num_epochs: Number of epochs for training.
414
        save_interval: Number of epochs between saving checkpoints.
415
    """
416
    # Loading in the data.
417
    data_transforms = get_data_transforms(
418
        color_jitter_brightness=color_jitter_brightness,
419
        color_jitter_contrast=color_jitter_contrast,
420
        color_jitter_hue=color_jitter_hue,
421
        color_jitter_saturation=color_jitter_saturation,
422
        path_mean=path_mean,
423
        path_std=path_std)
424
425
    image_datasets = {
426
        x: datasets.ImageFolder(root=str(train_folder.joinpath(x)),
427
                                transform=data_transforms[x])
428
        for x in ("train", "val")
429
    }
430
431
    dataloaders = {
432
        x: torch.utils.data.DataLoader(dataset=image_datasets[x],
433
                                       batch_size=batch_size,
434
                                       shuffle=(x is "train"),
435
                                       num_workers=num_workers)
436
        for x in ("train", "val")
437
    }
438
    dataset_sizes = {x: len(image_datasets[x]) for x in ("train", "val")}
439
440
    print(f"{num_classes} classes: {classes}\n"
441
          f"num train images {len(dataloaders['train']) * batch_size}\n"
442
          f"num val images {len(dataloaders['val']) * batch_size}\n"
443
          f"CUDA is_available: {torch.cuda.is_available()}")
444
445
    model = create_model(num_classes=num_classes,
446
                         num_layers=num_layers,
447
                         pretrain=pretrain)
448
    model = model.to(device=device)
449
    optimizer = optim.Adam(params=model.parameters(),
450
                           lr=learning_rate,
451
                           weight_decay=weight_decay)
452
    scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer,
453
                                           gamma=learning_rate_decay)
454
455
    # Initialize the model.
456
    if resume_checkpoint:
457
        ckpt = torch.load(f=resume_checkpoint_path)
458
        model.load_state_dict(state_dict=ckpt["model_state_dict"])
459
        optimizer.load_state_dict(state_dict=ckpt["optimizer_state_dict"])
460
        scheduler.load_state_dict(state_dict=ckpt["scheduler_state_dict"])
461
        start_epoch = ckpt["epoch"]
462
        print(f"model loaded from {resume_checkpoint_path}")
463
    else:
464
        start_epoch = 0
465
466
    # Print the model hyperparameters.
467
    print_params(batch_size=batch_size,
468
                 checkpoints_folder=checkpoints_folder,
469
                 learning_rate=learning_rate,
470
                 learning_rate_decay=learning_rate_decay,
471
                 log_csv=log_csv,
472
                 num_epochs=num_epochs,
473
                 num_layers=num_layers,
474
                 pretrain=pretrain,
475
                 resume_checkpoint=resume_checkpoint,
476
                 resume_checkpoint_path=resume_checkpoint_path,
477
                 save_interval=save_interval,
478
                 train_folder=train_folder,
479
                 weight_decay=weight_decay)
480
481
    # Logging the model after every epoch.
482
    # Confirm the output directory exists.
483
    log_csv.parent.mkdir(parents=True, exist_ok=True)
484
485
    with log_csv.open(mode="w") as writer:
486
        writer.write("epoch,train_loss,train_acc,val_loss,val_acc\n")
487
        # Train the model.
488
        train_helper(model=model,
489
                     dataloaders=dataloaders,
490
                     dataset_sizes=dataset_sizes,
491
                     criterion=nn.CrossEntropyLoss(),
492
                     optimizer=optimizer,
493
                     scheduler=scheduler,
494
                     start_epoch=start_epoch,
495
                     writer=writer,
496
                     batch_size=batch_size,
497
                     checkpoints_folder=checkpoints_folder,
498
                     device=device,
499
                     num_layers=num_layers,
500
                     save_interval=save_interval,
501
                     num_epochs=num_epochs,
502
                     classes=classes,
503
                     num_classes=num_classes)
504
505
506
###########################################
507
#      MAIN EVALUATION FUNCTION           #
508
###########################################
509
510
511
def parse_val_acc(model_path: Path) -> float:
512
    """
513
    Parse the validation accuracy from the filename.
514
515
    Args:
516
        model_path: The model path to parse for the validation accuracy.
517
518
    Returns:
519
        The parsed validation accuracy.
520
    """
521
    return float(
522
        f"{('.'.join(model_path.name.split('.')[:-1])).split('_')[-1][2:]}")
523
524
525
def get_best_model(checkpoints_folder: Path) -> str:
526
    """
527
    Finds the model with the best validation accuracy.
528
529
    Args:
530
        checkpoints_folder: Folder containing the models to test.
531
532
    Returns:
533
        The location of the model with the best validation accuracy.
534
    """
535
    return max({
536
        model: parse_val_acc(model_path=model)
537
        for model in [m for m in checkpoints_folder.rglob("*.pt") if ".DS_Store" not in str(m)]
538
    }.items(),
539
               key=operator.itemgetter(1))[0]
540
541
542
def get_predictions(patches_eval_folder: Path, output_folder: Path,
543
                    checkpoints_folder: Path, auto_select: bool,
544
                    eval_model: Path, device: torch.device, classes: List[str],
545
                    num_classes: int, path_mean: List[float],
546
                    path_std: List[float], num_layers: int, pretrain: bool,
547
                    batch_size: int, num_workers: int) -> None:
548
    """
549
    Main function for running the model on all of the generated patches.
550
551
    Args:
552
        patches_eval_folder: Folder containing patches to evaluate on.
553
        output_folder: Folder to save the model results to.
554
        checkpoints_folder: Directory to save model checkpoints to.
555
        auto_select: Automatically select the model with the highest validation accuracy,
556
        eval_model: Path to the model with the highest validation accuracy.
557
        device: Device to use for running model.
558
        classes: Names of the classes in the dataset.
559
        num_classes: Number of classes in the dataset.
560
        path_mean: Means of the WSIs for each dimension.
561
        path_std: Standard deviations of the WSIs for each dimension.
562
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
563
        pretrain: Use pretrained ResNet weights.
564
        batch_size: Mini-batch size to use for training.
565
        num_workers: Number of workers to use for IO.
566
    """
567
    # Initialize the model.
568
    model_path = get_best_model(
569
        checkpoints_folder=checkpoints_folder) if auto_select else eval_model
570
571
    model = create_model(num_classes=num_classes,
572
                         num_layers=num_layers,
573
                         pretrain=pretrain)
574
    ckpt = torch.load(f=model_path)
575
    model.load_state_dict(state_dict=ckpt["model_state_dict"])
576
    model = model.to(device=device)
577
578
    model.train(mode=False)
579
    print(f"model loaded from {model_path}")
580
581
    # For outputting the predictions.
582
    class_num_to_class = {i: classes[i] for i in range(num_classes)}
583
584
    start = time.time()
585
    # Load the data for each folder.
586
    image_folders = get_subfolder_paths(folder=patches_eval_folder)
587
588
    # Where we want to write out the predictions.
589
    # Confirm the output directory exists.
590
    output_folder.mkdir(parents=True, exist_ok=True)
591
592
    # For each WSI.
593
    for image_folder in image_folders:
594
595
        # Temporary fix. Need not to make folders with no crops.
596
        try:
597
            # Load the image dataset.
598
            dataloader = torch.utils.data.DataLoader(
599
                dataset=datasets.ImageFolder(
600
                    root=str(image_folder),
601
                    transform=transforms.Compose(transforms=[
602
                        transforms.ToTensor(),
603
                        transforms.Normalize(mean=path_mean, std=path_std)
604
                    ])),
605
                batch_size=batch_size,
606
                shuffle=False,
607
                num_workers=num_workers)
608
        except RuntimeError:
609
            print(
610
                "WARNING: One of the image directories is empty. Skipping this directory."
611
            )
612
            continue
613
614
        num_test_image_windows = len(dataloader) * batch_size
615
616
        # Load the image names so we know the coordinates of the patches we are predicting.
617
        image_folder = image_folder.joinpath(image_folder.name)
618
        window_names = get_image_paths(folder=image_folder)
619
620
        print(f"testing on {num_test_image_windows} crops from {image_folder}")
621
622
        with output_folder.joinpath(f"{image_folder.name}.csv").open(
623
                mode="w") as writer:
624
625
            writer.write("x,y,prediction,confidence\n")
626
627
            # Loop through all of the patches.
628
            for batch_num, (test_inputs, test_labels) in enumerate(dataloader):
629
                batch_window_names = window_names[batch_num *
630
                                                  batch_size:batch_num *
631
                                                  batch_size + batch_size]
632
633
                confidences, test_preds = torch.max(nn.Softmax(dim=1)(model(
634
                    test_inputs.to(device=device))),
635
                                                    dim=1)
636
                for i in range(test_preds.shape[0]):
637
                    # Find coordinates and predicted class.
638
                    xy = batch_window_names[i].name.split(".")[0].split(";")
639
640
                    writer.write(
641
                        f"{','.join([xy[0], xy[1], f'{class_num_to_class[test_preds[i].data.item()]}', f'{confidences[i].data.item():.5f}'])}\n"
642
                    )
643
644
    print(f"time for {patches_eval_folder}: {time.time() - start:.2f} seconds")