Diff of /code/utils_split.py [000000] .. [594161]

Switch to unified view

a b/code/utils_split.py
1
"""
2
DeepSlide
3
Splits the data into training, validation, and testing sets.
4
5
Authors: Jason Wei, Behnaz Abdollahi, Saeed Hassanpour
6
"""
7
8
import shutil
9
from pathlib import Path
10
from typing import (Dict, List)
11
12
from utils import (get_image_paths, get_subfolder_paths)
13
14
15
def split(keep_orig_copy: bool, wsi_train: Path, wsi_val: Path, wsi_test: Path,
          classes: List[str], all_wsi: Path, val_wsi_per_class: int,
          test_wsi_per_class: int, labels_train: Path, labels_test: Path,
          labels_val: Path) -> None:
    """
    Main function for splitting data. Note that we want the
    validation and test sets to be balanced.

    For every class subfolder of all_wsi, the last test_wsi_per_class images
    (in the order returned by get_image_paths) go to the test set, the
    val_wsi_per_class images before those go to the validation set, and all
    remaining images go to the training set. One CSV file of
    filename/label pairs is written per set.

    Args:
        keep_orig_copy: If True, the WSI are copied and the originals are kept; if False, the WSI are moved.
        wsi_train: Location to be created to store WSI for training.
        wsi_val: Location to be created to store WSI for validation.
        wsi_test: Location to be created to store WSI for testing.
        classes: Names of the classes in the dataset.
        all_wsi: Location of the WSI organized in subfolders by class.
        val_wsi_per_class: Number of WSI per class to use in the validation set.
        test_wsi_per_class: Number of WSI per class to use in the test set.
        labels_train: Location to store the CSV file labels for training.
        labels_test: Location to store the CSV file labels for testing.
        labels_val: Location to store the CSV file labels for validation.

    Raises:
        AssertionError: If a class subfolder does not contain strictly more
            images than val_wsi_per_class + test_wsi_per_class.
    """
    # Imported locally so this function does not rely on the exact contents of
    # the module-level typing import.
    from typing import Callable

    # Based on whether we want to move or keep the files.
    head = shutil.copyfile if keep_orig_copy else shutil.move

    # Create one subfolder per class in each of the three output folders.
    for f in (wsi_train, wsi_val, wsi_test):
        for _class in classes:
            f.joinpath(_class).mkdir(parents=True, exist_ok=True)

    train_img_to_label = {}
    val_img_to_label = {}
    test_img_to_label = {}

    def move_set(folder: Path, image_files: List[Path],
                 ops: Callable[..., object]) -> Dict[Path, str]:
        """
        Moves the sets to the desired output directories.

        Each image ends up at folder/<class name>/<file name>, where the class
        name is the name of the directory that directly contains the image.

        Args:
            folder: Folder to move images to.
            image_files: Image files to move.
            ops: Function called as ops(src=..., dst=...) to move or copy a file.

        Return:
            A dictionary mapping image filenames to classes.
        """
        img_to_label = {}
        for image_file in image_files:
            class_name = image_file.parent.name

            # Build the destination from the class folder and file name only.
            # Stripping just the first path component breaks as soon as
            # all_wsi is absolute or more than one directory deep.
            destination = folder.joinpath(class_name, image_file.name)

            # The class folder may be missing when all_wsi contains a
            # subfolder that is not listed in classes.
            destination.parent.mkdir(parents=True, exist_ok=True)

            # Copy or move the files.
            ops(src=image_file, dst=destination)

            img_to_label[Path(image_file.name)] = class_name

        return img_to_label

    # Sort the images and move/copy them appropriately.
    subfolder_paths = get_subfolder_paths(folder=all_wsi)
    for subfolder in subfolder_paths:
        image_paths = get_image_paths(folder=subfolder)

        # Make sure we have enough slides in each class. The strict inequality
        # guarantees at least one training image per class.
        assert len(image_paths) > val_wsi_per_class + test_wsi_per_class, \
            "Not enough slides in each class."

        # Assign training, test, and validation images.
        test_idx = len(image_paths) - test_wsi_per_class
        val_idx = test_idx - val_wsi_per_class
        train_images = image_paths[:val_idx]
        val_images = image_paths[val_idx:test_idx]
        test_images = image_paths[test_idx:]
        print(f"class {Path(subfolder).name} "
              f"#train={len(train_images)} "
              f"#val={len(val_images)} "
              f"#test={len(test_images)}")

        # Move the training images.
        train_img_to_label.update(
            move_set(folder=wsi_train, image_files=train_images, ops=head))

        # Move the validation images.
        val_img_to_label.update(
            move_set(folder=wsi_val, image_files=val_images, ops=head))

        # Move the testing images.
        test_img_to_label.update(
            move_set(folder=wsi_test, image_files=test_images, ops=head))

    def write_to_csv(dest_filename: Path,
                     image_label_dict: Dict[Path, str]) -> None:
        """
        Write the image names and corresponding labels to a CSV file.

        The file starts with the header "img,gt" and lists the images in
        sorted filename order.

        Args:
            dest_filename: Destination filename for the CSV file.
            image_label_dict: Dictionary mapping filenames to labels.
        """
        # Make sure the directory that will hold the CSV file exists.
        dest_filename.parent.mkdir(parents=True, exist_ok=True)

        with dest_filename.open(mode="w") as writer:
            writer.write("img,gt\n")
            for img in sorted(image_label_dict):
                writer.write(f"{img},{image_label_dict[img]}\n")

    write_to_csv(dest_filename=labels_train,
                 image_label_dict=train_img_to_label)
    write_to_csv(dest_filename=labels_val, image_label_dict=val_img_to_label)
    write_to_csv(dest_filename=labels_test, image_label_dict=test_img_to_label)