--- a +++ b/code/utils_split.py @@ -0,0 +1,137 @@ +""" +DeepSlide +Splits the data into training, validation, and testing sets. + +Authors: Jason Wei, Behnaz Abdollahi, Saeed Hassanpour +""" + +import shutil +from pathlib import Path +from typing import (Dict, List) + +from utils import (get_image_paths, get_subfolder_paths) + + +def split(keep_orig_copy: bool, wsi_train: Path, wsi_val: Path, wsi_test: Path, + classes: List[str], all_wsi: Path, val_wsi_per_class: int, + test_wsi_per_class: int, labels_train: Path, labels_test: Path, + labels_val: Path) -> None: + """ + Main function for splitting data. Note that we want the + validation and test sets to be balanced. + + Args: + keep_orig_copy: Whether to move or copy the WSI when splitting into training, validation, and test sets. + wsi_train: Location to be created to store WSI for training. + wsi_val: Location to be created to store WSI for validation. + wsi_test: Location to be created to store WSI for testing. + classes: Names of the classes in the dataset. + all_wsi: Location of the WSI organized in subfolders by class. + val_wsi_per_class: Number of WSI per class to use in the validation set. + test_wsi_per_class: Number of WSI per class to use in the test set. + labels_train: Location to store the CSV file labels for training. + labels_test: Location to store the CSV file labels for testing. + labels_val: Location to store the CSV file labels for validation. + """ + # Based on whether we want to move or keep the files. + head = shutil.copyfile if keep_orig_copy else shutil.move + + # Create folders. + for f in (wsi_train, wsi_val, wsi_test): + subfolders = [f.joinpath(_class) for _class in classes] + + for subfolder in subfolders: + # Confirm the output directory exists. + subfolder.mkdir(parents=True, exist_ok=True) + + train_img_to_label = {} + val_img_to_label = {} + test_img_to_label = {} + + def move_set(folder: Path, image_files: List[Path], + ops: shutil) -> Dict[Path, str]: + """ + Moves the sets to the desired output directories. + + Args: + folder: Folder to move images to. + image_files: Image files to move. + ops: Whether to move or copy the files. + + Return: + A dictionary mapping image filenames to classes. + """ + def remove_topdir(filepath: Path) -> Path: + """ + Remove the top directory since the filepath needs to be + a relative path (i.e., a/b/c.jpg -> b/c.jpg). + + Args: + filepath: Path to remove top directory from. + + Returns: + Path with top directory removed. + """ + return Path(*filepath.parts[1:]) + + img_to_label = {} + for image_file in image_files: + # Copy or move the files. + ops(src=image_file, + dst=folder.joinpath(remove_topdir(filepath=image_file))) + + img_to_label[Path(image_file.name)] = image_file.parent.name + + return img_to_label + + # Sort the images and move/copy them appropriately. + subfolder_paths = get_subfolder_paths(folder=all_wsi) + for subfolder in subfolder_paths: + image_paths = get_image_paths(folder=subfolder) + + # Make sure we have enough slides in each class. + assert len( + image_paths + ) > val_wsi_per_class + test_wsi_per_class, "Not enough slides in each class." + + # Assign training, test, and validation images. + test_idx = len(image_paths) - test_wsi_per_class + val_idx = test_idx - val_wsi_per_class + train_images = image_paths[:val_idx] + val_images = image_paths[val_idx:test_idx] + test_images = image_paths[test_idx:] + print(f"class {Path(subfolder).name} " + f"#train={len(train_images)} " + f"#val={len(val_images)} " + f"#test={len(test_images)}") + + # Move the training images. + train_img_to_label.update( + move_set(folder=wsi_train, image_files=train_images, ops=head)) + + # Move the validation images. + val_img_to_label.update( + move_set(folder=wsi_val, image_files=val_images, ops=head)) + + # Move the testing images. + test_img_to_label.update( + move_set(folder=wsi_test, image_files=test_images, ops=head)) + + def write_to_csv(dest_filename: Path, + image_label_dict: Dict[Path, str]) -> None: + """ + Write the image names and corresponding labels to a CSV file. + + Args: + dest_filename: Destination filename for the CSV file. + image_label_dict: Dictionary mapping filenames to labels. + """ + with dest_filename.open(mode="w") as writer: + writer.write("img,gt\n") + for img in sorted(image_label_dict.keys()): + writer.write(f"{img},{image_label_dict[img]}\n") + + write_to_csv(dest_filename=labels_train, + image_label_dict=train_img_to_label) + write_to_csv(dest_filename=labels_val, image_label_dict=val_img_to_label) + write_to_csv(dest_filename=labels_test, image_label_dict=test_img_to_label)