Diff of /code/utils_model.py [000000] .. [594161]

--- a
+++ b/code/utils_model.py
@@ -0,0 +1,644 @@
+"""
+DeepSlide
+Using ResNet to train and test.
+
+Authors: Jason Wei, Behnaz Abdollahi, Saeed Hassanpour, Naofumi Tomita
+"""
+
+import random
+import time
+from pathlib import Path
+from typing import (Dict, IO, List, Tuple)
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision
+from PIL import Image
+from torch.optim import lr_scheduler
+from torchvision import (datasets, transforms)
+
+from utils import (get_image_paths, get_subfolder_paths)
+
+###########################################
+#             MISC FUNCTIONS              #
+###########################################
+
+
+def calculate_confusion_matrix(all_labels: np.ndarray,
+                               all_predicts: np.ndarray, classes: List[str],
+                               num_classes: int) -> None:
+    """
+    Prints the confusion matrix from the given data.
+
+    Args:
+        all_labels: The ground truth labels.
+        all_predicts: The predicted labels.
+        classes: Names of the classes in the dataset.
+        num_classes: Number of classes in the dataset.
+    """
+    remap_classes = {x: classes[x] for x in range(num_classes)}
+
+    # Set print options.
+    # Sources:
+    #   1. https://stackoverflow.com/questions/42735541/customized-float-formatting-in-a-pandas-dataframe
+    #   2. https://stackoverflow.com/questions/11707586/how-do-i-expand-the-output-display-to-see-more-columns-of-a-pandas-dataframe
+    #   3. https://pandas.pydata.org/pandas-docs/stable/user_guide/style.html
+    pd.options.display.float_format = "{:.2f}".format
+    pd.options.display.width = 0
+
+    actual = pd.Series(pd.Categorical(
+        pd.Series(all_labels).replace(remap_classes), categories=classes),
+                       name="Actual")
+
+    predicted = pd.Series(pd.Categorical(
+        pd.Series(all_predicts).replace(remap_classes), categories=classes),
+                          name="Predicted")
+
+    # With normalize="index", each row of the confusion matrix sums to 1,
+    # i.e. the entries are per-class fractions of the true labels.
+    cm = pd.crosstab(index=actual,
+                     columns=predicted,
+                     normalize="index",
+                     dropna=False)
+
+    # Note: Styler.hide_index() was deprecated in pandas 1.4.0 in favor of
+    # Styler.hide(). Either call only affects HTML rendering via a Styler
+    # object, not the plain-text print() below, so neither is used here.
+    print(cm)
+
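+# Illustrative usage (a sketch with made-up arrays; each printed row sums
+# to 1 because of normalize="index"):
+#
+#     calculate_confusion_matrix(all_labels=np.array([0, 1, 1]),
+#                                all_predicts=np.array([0, 1, 0]),
+#                                classes=["benign", "malignant"],
+#                                num_classes=2)
+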
+
+class Random90Rotation:
+    def __init__(self, degrees: Tuple[int, ...] = None) -> None:
+        """
+        Randomly rotate the image for training. Credits to Naofumi Tomita.
+
+        Args:
+            degrees: Degrees available for rotation.
+        """
+        self.degrees = (0, 90, 180, 270) if (degrees is None) else degrees
+
+    def __call__(self, im: Image.Image) -> Image.Image:
+        """
+        Produces a randomly rotated image every time the instance is called.
+
+        Args:
+            im: The image to rotate.
+
+        Returns:
+            Randomly rotated image.
+        """
+        return im.rotate(angle=random.choice(self.degrees))
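+
+# Illustrative usage (a hedged sketch, not part of the training pipeline;
+# "patch.png" is a placeholder path):
+#
+#     im = Image.open("patch.png")
+#     rotated = Random90Rotation()(im)  # rotated by 0, 90, 180, or 270 deg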
+
+
+def create_model(num_layers: int, num_classes: int,
+                 pretrain: bool) -> torchvision.models.resnet.ResNet:
+    """
+    Instantiate the ResNet model.
+
+    Args:
+        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
+        num_classes: Number of classes in the dataset.
+        pretrain: Use pretrained ResNet weights.
+
+    Returns:
+        The instantiated ResNet model with the requested parameters.
+    """
+    assert num_layers in (
+        18, 34, 50, 101, 152
+    ), f"Invalid number of ResNet Layers. Must be one of [18, 34, 50, 101, 152] and not {num_layers}"
+    model_constructor = getattr(torchvision.models, f"resnet{num_layers}")
+    model = model_constructor(num_classes=num_classes)
+
+    if pretrain:
+        pretrained = model_constructor(pretrained=True).state_dict()
+        if num_classes != pretrained["fc.weight"].size(0):
+            del pretrained["fc.weight"], pretrained["fc.bias"]
+        model.load_state_dict(state_dict=pretrained, strict=False)
+    return model
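+
+# Illustrative usage (a sketch; the layer and class counts are arbitrary):
+#
+#     model = create_model(num_layers=18, num_classes=3, pretrain=True)
+#
+# When num_classes differs from the 1000-class ImageNet head, the fc weights
+# are dropped from the pretrained state dict above, so the final layer is
+# freshly initialized while the backbone keeps its pretrained weights.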
+
+
+def get_data_transforms(color_jitter_brightness: float,
+                        color_jitter_contrast: float,
+                        color_jitter_saturation: float,
+                        color_jitter_hue: float, path_mean: List[float],
+                        path_std: List[float]
+                        ) -> Dict[str, torchvision.transforms.Compose]:
+    """
+    Sets up the dataset transforms for training and validation.
+
+    Args:
+        color_jitter_brightness: Random brightness jitter to use in data augmentation for ColorJitter() transform.
+        color_jitter_contrast: Random contrast jitter to use in data augmentation for ColorJitter() transform.
+        color_jitter_saturation: Random saturation jitter to use in data augmentation for ColorJitter() transform.
+        color_jitter_hue: Random hue jitter to use in data augmentation for ColorJitter() transform.
+        path_mean: Means of the WSIs for each dimension.
+        path_std: Standard deviations of the WSIs for each dimension.
+
+    Returns:
+        A dictionary mapping training and validation strings to data transforms.
+    """
+    return {
+        "train":
+        transforms.Compose(transforms=[
+            transforms.ColorJitter(brightness=color_jitter_brightness,
+                                   contrast=color_jitter_contrast,
+                                   saturation=color_jitter_saturation,
+                                   hue=color_jitter_hue),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+            Random90Rotation(),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=path_mean, std=path_std)
+        ]),
+        "val":
+        transforms.Compose(transforms=[
+            transforms.ToTensor(),
+            transforms.Normalize(mean=path_mean, std=path_std)
+        ])
+    }
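+
+# Illustrative usage (a sketch; the jitter values and channel statistics are
+# placeholders, not recommended settings):
+#
+#     data_transforms = get_data_transforms(
+#         color_jitter_brightness=0.5, color_jitter_contrast=0.5,
+#         color_jitter_saturation=0.5, color_jitter_hue=0.2,
+#         path_mean=[0.7, 0.6, 0.7], path_std=[0.15, 0.15, 0.15])
+#     train_transform = data_transforms["train"]  # keys: "train" and "val"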
+
+
+def print_params(train_folder: Path, num_epochs: int, num_layers: int,
+                 learning_rate: float, batch_size: int, weight_decay: float,
+                 learning_rate_decay: float, resume_checkpoint: bool,
+                 resume_checkpoint_path: Path, save_interval: int,
+                 checkpoints_folder: Path, pretrain: bool,
+                 log_csv: Path) -> None:
+    """
+    Print the configuration of the model.
+
+    Args:
+        train_folder: Location of the automatically built training input folder.
+        num_epochs: Number of epochs for training.
+        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
+        learning_rate: Learning rate to use for gradient descent.
+        batch_size: Mini-batch size to use for training.
+        weight_decay: Weight decay (L2 penalty) to use in optimizer.
+        learning_rate_decay: Learning rate decay amount per epoch.
+        resume_checkpoint: Resume model from checkpoint file.
+        resume_checkpoint_path: Path to the checkpoint file for resuming training.
+        save_interval: Number of epochs between saving checkpoints.
+        checkpoints_folder: Directory to save model checkpoints to.
+        pretrain: Use pretrained ResNet weights.
+        log_csv: Name of the CSV file containing the logs.
+    """
+    print(f"train_folder: {train_folder}\n"
+          f"num_epochs: {num_epochs}\n"
+          f"num_layers: {num_layers}\n"
+          f"learning_rate: {learning_rate}\n"
+          f"batch_size: {batch_size}\n"
+          f"weight_decay: {weight_decay}\n"
+          f"learning_rate_decay: {learning_rate_decay}\n"
+          f"resume_checkpoint: {resume_checkpoint}\n"
+          f"resume_checkpoint_path (only if resume_checkpoint is true): "
+          f"{resume_checkpoint_path}\n"
+          f"save_interval: {save_interval}\n"
+          f"output in checkpoints_folder: {checkpoints_folder}\n"
+          f"pretrain: {pretrain}\n"
+          f"log_csv: {log_csv}\n\n")
+
+
+###########################################
+#          MAIN TRAIN FUNCTION            #
+###########################################
+
+
+def train_helper(model: torchvision.models.resnet.ResNet,
+                 dataloaders: Dict[str, torch.utils.data.DataLoader],
+                 dataset_sizes: Dict[str, int],
+                 criterion: nn.Module, optimizer: optim.Optimizer,
+                 scheduler: lr_scheduler.ExponentialLR, num_epochs: int,
+                 writer: IO, device: torch.device, start_epoch: int,
+                 batch_size: int, save_interval: int, checkpoints_folder: Path,
+                 num_layers: int, classes: List[str],
+                 num_classes: int) -> None:
+    """
+    Function for training ResNet.
+
+    Args:
+        model: ResNet model for training.
+        dataloaders: Dataloaders for IO pipeline.
+        dataset_sizes: Sizes of the training and validation dataset.
+        criterion: Metric used for calculating loss.
+        optimizer: Optimizer to use for gradient descent.
+        scheduler: Scheduler to use for learning rate decay.
+        num_epochs: Total number of epochs to train for.
+        writer: Writer to write logging information.
+        device: Device to use for running model.
+        start_epoch: Starting epoch for training.
+        batch_size: Mini-batch size to use for training.
+        save_interval: Number of epochs between saving checkpoints.
+        checkpoints_folder: Directory to save model checkpoints to.
+        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
+        classes: Names of the classes in the dataset.
+        num_classes: Number of classes in the dataset.
+    """
+    since = time.time()
+
+    # Initialize all the tensors to be used in training and validation.
+    # Do this outside the loop since it will be written over entirely at each
+    # epoch and doesn't need to be reallocated each time.
+    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
+                                   dtype=torch.long).cpu()
+    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
+                                     dtype=torch.long).cpu()
+    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
+                                 dtype=torch.long).cpu()
+    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
+                                   dtype=torch.long).cpu()
+
+    # Train for specified number of epochs.
+    for epoch in range(start_epoch, num_epochs):
+
+        # Training phase.
+        model.train(mode=True)
+
+        train_running_loss = 0.0
+        train_running_corrects = 0
+
+        # Train over all training data.
+        for idx, (inputs, labels) in enumerate(dataloaders["train"]):
+            train_inputs = inputs.to(device=device)
+            train_labels = labels.to(device=device)
+            optimizer.zero_grad()
+
+            # Forward and backpropagation.
+            with torch.set_grad_enabled(mode=True):
+                train_outputs = model(train_inputs)
+                _, train_preds = torch.max(train_outputs, dim=1)
+                train_loss = criterion(input=train_outputs,
+                                       target=train_labels)
+                train_loss.backward()
+                optimizer.step()
+
+            # Update training diagnostics.
+            train_running_loss += train_loss.item() * train_inputs.size(0)
+            train_running_corrects += torch.sum(
+                train_preds == train_labels.data, dtype=torch.double)
+
+            start = idx * batch_size
+            end = start + batch_size
+
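+            # Note: the final batch may be smaller than batch_size; the
+            # slice assignment below is clamped to the tensor's length,
+            # which matches the number of labels in that batch.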
+            train_all_labels[start:end] = train_labels.detach().cpu()
+            train_all_predicts[start:end] = train_preds.detach().cpu()
+
+        calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
+                                   all_predicts=train_all_predicts.numpy(),
+                                   classes=classes,
+                                   num_classes=num_classes)
+
+        # Store training diagnostics.
+        train_loss = train_running_loss / dataset_sizes["train"]
+        train_acc = train_running_corrects / dataset_sizes["train"]
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Validation phase.
+        model.eval()
+
+        val_running_loss = 0.0
+        val_running_corrects = 0
+
+        # Feed forward over all the validation data.
+        for idx, (val_inputs, val_labels) in enumerate(dataloaders["val"]):
+            val_inputs = val_inputs.to(device=device)
+            val_labels = val_labels.to(device=device)
+
+            # Feed forward.
+            with torch.set_grad_enabled(mode=False):
+                val_outputs = model(val_inputs)
+                _, val_preds = torch.max(val_outputs, dim=1)
+                val_loss = criterion(input=val_outputs, target=val_labels)
+
+            # Update validation diagnostics.
+            val_running_loss += val_loss.item() * val_inputs.size(0)
+            val_running_corrects += torch.sum(val_preds == val_labels.data,
+                                              dtype=torch.double)
+
+            start = idx * batch_size
+            end = start + batch_size
+
+            val_all_labels[start:end] = val_labels.detach().cpu()
+            val_all_predicts[start:end] = val_preds.detach().cpu()
+
+        calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
+                                   all_predicts=val_all_predicts.numpy(),
+                                   classes=classes,
+                                   num_classes=num_classes)
+
+        # Store validation diagnostics.
+        val_loss = val_running_loss / dataset_sizes["val"]
+        val_acc = val_running_corrects / dataset_sizes["val"]
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        scheduler.step()
+
+        current_lr = optimizer.param_groups[-1]["lr"]
+
+        # Remaining things related to training.
+        if epoch % save_interval == 0:
+            epoch_output_path = checkpoints_folder.joinpath(
+                f"resnet{num_layers}_e{epoch}_va{val_acc:.5f}.pt")
+
+            # Confirm the output directory exists.
+            epoch_output_path.parent.mkdir(parents=True, exist_ok=True)
+
+            # Save the model as a state dictionary.
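+            # Note: "epoch" is stored as epoch + 1 so that a resumed run
+            # (see train_resnet) starts at the epoch after this checkpoint.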
+            torch.save(obj={
+                "model_state_dict": model.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "scheduler_state_dict": scheduler.state_dict(),
+                "epoch": epoch + 1
+            },
+                       f=str(epoch_output_path))
+
+        writer.write(f"{epoch},{train_loss:.4f},"
+                     f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")
+
+        # Print the diagnostics for each epoch.
+        print(f"Epoch {epoch} with lr "
+              f"{current_lr:.15f}: "
+              f"t_loss: {train_loss:.4f} "
+              f"t_acc: {train_acc:.4f} "
+              f"v_loss: {val_loss:.4f} "
+              f"v_acc: {val_acc:.4f}\n")
+
+    # Print training information at the end.
+    print(f"\ntraining complete in "
+          f"{(time.time() - since) // 60:.2f} minutes")
+
+
+def train_resnet(
+        train_folder: Path, batch_size: int, num_workers: int,
+        device: torch.device, classes: List[str], learning_rate: float,
+        weight_decay: float, learning_rate_decay: float,
+        resume_checkpoint: bool, resume_checkpoint_path: Path, log_csv: Path,
+        color_jitter_brightness: float, color_jitter_contrast: float,
+        color_jitter_hue: float, color_jitter_saturation: float,
+        path_mean: List[float], path_std: List[float], num_classes: int,
+        num_layers: int, pretrain: bool, checkpoints_folder: Path,
+        num_epochs: int, save_interval: int) -> None:
+    """
+    Main function for training ResNet.
+
+    Args:
+        train_folder: Location of the automatically built training input folder.
+        batch_size: Mini-batch size to use for training.
+        num_workers: Number of workers to use for IO.
+        device: Device to use for running model.
+        classes: Names of the classes in the dataset.
+        learning_rate: Learning rate to use for gradient descent.
+        weight_decay: Weight decay (L2 penalty) to use in optimizer.
+        learning_rate_decay: Learning rate decay amount per epoch.
+        resume_checkpoint: Resume model from checkpoint file.
+        resume_checkpoint_path: Path to the checkpoint file for resuming training.
+        log_csv: Name of the CSV file containing the logs.
+        color_jitter_brightness: Random brightness jitter to use in data augmentation for ColorJitter() transform.
+        color_jitter_contrast: Random contrast jitter to use in data augmentation for ColorJitter() transform.
+        color_jitter_hue: Random hue jitter to use in data augmentation for ColorJitter() transform.
+        color_jitter_saturation: Random saturation jitter to use in data augmentation for ColorJitter() transform.
+        path_mean: Means of the WSIs for each dimension.
+        path_std: Standard deviations of the WSIs for each dimension.
+        num_classes: Number of classes in the dataset.
+        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
+        pretrain: Use pretrained ResNet weights.
+        checkpoints_folder: Directory to save model checkpoints to.
+        num_epochs: Number of epochs for training.
+        save_interval: Number of epochs between saving checkpoints.
+    """
+    # Loading in the data.
+    data_transforms = get_data_transforms(
+        color_jitter_brightness=color_jitter_brightness,
+        color_jitter_contrast=color_jitter_contrast,
+        color_jitter_hue=color_jitter_hue,
+        color_jitter_saturation=color_jitter_saturation,
+        path_mean=path_mean,
+        path_std=path_std)
+
+    image_datasets = {
+        x: datasets.ImageFolder(root=str(train_folder.joinpath(x)),
+                                transform=data_transforms[x])
+        for x in ("train", "val")
+    }
+
+    dataloaders = {
+        x: torch.utils.data.DataLoader(dataset=image_datasets[x],
+                                       batch_size=batch_size,
+                                       shuffle=(x == "train"),
+                                       num_workers=num_workers)
+        for x in ("train", "val")
+    }
+    dataset_sizes = {x: len(image_datasets[x]) for x in ("train", "val")}
+
+    print(f"{num_classes} classes: {classes}\n"
+          f"num train images {len(dataloaders['train']) * batch_size}\n"
+          f"num val images {len(dataloaders['val']) * batch_size}\n"
+          f"CUDA is_available: {torch.cuda.is_available()}")
+
+    model = create_model(num_classes=num_classes,
+                         num_layers=num_layers,
+                         pretrain=pretrain)
+    model = model.to(device=device)
+    optimizer = optim.Adam(params=model.parameters(),
+                           lr=learning_rate,
+                           weight_decay=weight_decay)
+    scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer,
+                                           gamma=learning_rate_decay)
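+    # ExponentialLR multiplies the learning rate by gamma after every
+    # scheduler.step() (called once per epoch in train_helper), so the
+    # learning rate at epoch k is learning_rate * learning_rate_decay ** k.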
+
+    # Initialize the model.
+    if resume_checkpoint:
+        ckpt = torch.load(f=resume_checkpoint_path, map_location=device)
+        model.load_state_dict(state_dict=ckpt["model_state_dict"])
+        optimizer.load_state_dict(state_dict=ckpt["optimizer_state_dict"])
+        scheduler.load_state_dict(state_dict=ckpt["scheduler_state_dict"])
+        start_epoch = ckpt["epoch"]
+        print(f"model loaded from {resume_checkpoint_path}")
+    else:
+        start_epoch = 0
+
+    # Print the model hyperparameters.
+    print_params(batch_size=batch_size,
+                 checkpoints_folder=checkpoints_folder,
+                 learning_rate=learning_rate,
+                 learning_rate_decay=learning_rate_decay,
+                 log_csv=log_csv,
+                 num_epochs=num_epochs,
+                 num_layers=num_layers,
+                 pretrain=pretrain,
+                 resume_checkpoint=resume_checkpoint,
+                 resume_checkpoint_path=resume_checkpoint_path,
+                 save_interval=save_interval,
+                 train_folder=train_folder,
+                 weight_decay=weight_decay)
+
+    # Logging the model after every epoch.
+    # Confirm the output directory exists.
+    log_csv.parent.mkdir(parents=True, exist_ok=True)
+
+    with log_csv.open(mode="w") as writer:
+        writer.write("epoch,train_loss,train_acc,val_loss,val_acc\n")
+        # Train the model.
+        train_helper(model=model,
+                     dataloaders=dataloaders,
+                     dataset_sizes=dataset_sizes,
+                     criterion=nn.CrossEntropyLoss(),
+                     optimizer=optimizer,
+                     scheduler=scheduler,
+                     start_epoch=start_epoch,
+                     writer=writer,
+                     batch_size=batch_size,
+                     checkpoints_folder=checkpoints_folder,
+                     device=device,
+                     num_layers=num_layers,
+                     save_interval=save_interval,
+                     num_epochs=num_epochs,
+                     classes=classes,
+                     num_classes=num_classes)
+
+
+###########################################
+#      MAIN EVALUATION FUNCTION           #
+###########################################
+
+
+def parse_val_acc(model_path: Path) -> float:
+    """
+    Parse the validation accuracy from the filename.
+
+    Args:
+        model_path: The model path to parse for the validation accuracy.
+
+    Returns:
+        The parsed validation accuracy.
+    """
+    return float(model_path.stem.split("_")[-1][2:])
+
+
+def get_best_model(checkpoints_folder: Path) -> Path:
+    """
+    Finds the model with the best validation accuracy.
+
+    Args:
+        checkpoints_folder: Folder containing the models to test.
+
+    Returns:
+        The path of the model with the best validation accuracy.
+    """
+    return max((model for model in checkpoints_folder.rglob("*.pt")
+                if ".DS_Store" not in str(model)),
+               key=parse_val_acc)
+
+
+def get_predictions(patches_eval_folder: Path, output_folder: Path,
+                    checkpoints_folder: Path, auto_select: bool,
+                    eval_model: Path, device: torch.device, classes: List[str],
+                    num_classes: int, path_mean: List[float],
+                    path_std: List[float], num_layers: int, pretrain: bool,
+                    batch_size: int, num_workers: int) -> None:
+    """
+    Main function for running the model on all of the generated patches.
+
+    Args:
+        patches_eval_folder: Folder containing patches to evaluate on.
+        output_folder: Folder to save the model results to.
+        checkpoints_folder: Directory to save model checkpoints to.
+        auto_select: Automatically select the model with the highest validation accuracy.
+        eval_model: Path to the model with the highest validation accuracy.
+        device: Device to use for running model.
+        classes: Names of the classes in the dataset.
+        num_classes: Number of classes in the dataset.
+        path_mean: Means of the WSIs for each dimension.
+        path_std: Standard deviations of the WSIs for each dimension.
+        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
+        pretrain: Use pretrained ResNet weights.
+        batch_size: Mini-batch size to use for evaluation.
+        num_workers: Number of workers to use for IO.
+    """
+    # Initialize the model.
+    model_path = get_best_model(
+        checkpoints_folder=checkpoints_folder) if auto_select else eval_model
+
+    model = create_model(num_classes=num_classes,
+                         num_layers=num_layers,
+                         pretrain=pretrain)
+    ckpt = torch.load(f=model_path, map_location=device)
+    model.load_state_dict(state_dict=ckpt["model_state_dict"])
+    model = model.to(device=device)
+
+    model.eval()
+    print(f"model loaded from {model_path}")
+
+    # For outputting the predictions.
+    class_num_to_class = {i: classes[i] for i in range(num_classes)}
+
+    start = time.time()
+    # Load the data for each folder.
+    image_folders = get_subfolder_paths(folder=patches_eval_folder)
+
+    # Where we want to write out the predictions.
+    # Confirm the output directory exists.
+    output_folder.mkdir(parents=True, exist_ok=True)
+
+    # For each WSI.
+    for image_folder in image_folders:
+
+        # Temporary fix: skip image folders that contain no crops instead of
+        # failing; ideally, empty folders would not be generated at all.
+        try:
+            # Load the image dataset.
+            dataloader = torch.utils.data.DataLoader(
+                dataset=datasets.ImageFolder(
+                    root=str(image_folder),
+                    transform=transforms.Compose(transforms=[
+                        transforms.ToTensor(),
+                        transforms.Normalize(mean=path_mean, std=path_std)
+                    ])),
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=num_workers)
+        except RuntimeError:
+            print(
+                "WARNING: One of the image directories is empty. Skipping this directory."
+            )
+            continue
+
+        num_test_image_windows = len(dataloader.dataset)
+
+        # Load the image names so we know the coordinates of the patches we are predicting.
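+        # ImageFolder requires one subdirectory per class; the patches for a
+        # WSI live in a single subfolder named after the WSI, so descend into
+        # it to list the actual patch files.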
+        image_folder = image_folder.joinpath(image_folder.name)
+        window_names = get_image_paths(folder=image_folder)
+
+        print(f"testing on {num_test_image_windows} crops from {image_folder}")
+
+        with output_folder.joinpath(f"{image_folder.name}.csv").open(
+                mode="w") as writer:
+
+            writer.write("x,y,prediction,confidence\n")
+
+            # Loop through all of the patches.
+            for batch_num, (test_inputs, test_labels) in enumerate(dataloader):
+                batch_window_names = window_names[batch_num * batch_size:
+                                                  (batch_num + 1) * batch_size]
+
+                # Disable gradient tracking for inference.
+                with torch.no_grad():
+                    confidences, test_preds = torch.max(
+                        nn.Softmax(dim=1)(model(test_inputs.to(device=device))),
+                        dim=1)
+                for i in range(test_preds.shape[0]):
+                    # Patch filenames encode the patch coordinates as
+                    # "{x};{y}.<ext>"; recover them along with the
+                    # predicted class.
+                    xy = batch_window_names[i].name.split(".")[0].split(";")
+
+                    writer.write(
+                        f"{xy[0]},{xy[1]},"
+                        f"{class_num_to_class[test_preds[i].data.item()]},"
+                        f"{confidences[i].data.item():.5f}\n")
+
+    print(f"time for {patches_eval_folder}: {time.time() - start:.2f} seconds")