[96354c]: / src / dataset / patching / equal_label_distribution.py

Download this file

12 lines (10 with data), 504 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
import numpy as np
from typing import Tuple
from src.dataset.patching.commons import select_patch_by_label_distribution
def patching(
    volume: np.ndarray,
    labels: np.ndarray,
    patch_size: tuple,
    mask: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract one patch with equal probability from each label.

    The label mapping handed to ``select_patch_by_label_distribution`` is the
    identity, so no label values are merged or remapped before the helper
    decides where to sample.  Presumably the helper then draws each distinct
    label with equal probability, as the original description states
    ("Patches with equal probability from each label").  TODO: confirm
    against ``src.dataset.patching.commons``.

    Args:
        volume: Image volume to cut the patch from.
        labels: Label volume accompanying ``volume``.  It is assumed to have
            the same spatial shape as ``volume`` (TODO confirm).
        patch_size: Size of the patch to extract.
        mask: Currently UNUSED.  It is neither read nor forwarded to the
            helper.  NOTE(review): if the helper can restrict sampling to a
            region, ``mask`` should probably be passed through.  It is likely
            kept only so this function matches the signature of the other
            patching strategies (verify).

    Returns:
        ``(volume_patch, seg_patch)``: the two values produced by
        ``select_patch_by_label_distribution``.
    """

    def _identity(label):
        # Map every label onto itself so the helper sees the raw label values.
        return label

    # Unpack explicitly: this enforces that the helper yields exactly two
    # values and always returns a plain tuple to the caller.
    volume_patch, seg_patch = select_patch_by_label_distribution(
        volume, labels, patch_size, _identity
    )
    return volume_patch, seg_patch