a b/data_tools.py
1
import os
2
import subprocess
3
from distutils import dir_util
4
from pathlib import Path
5
6
import cv2 as cv
7
import numpy as np
8
from PIL import Image
9
from matplotlib import pyplot as plt
10
from tqdm import tqdm
11
12
from features import NucleiFeatures
13
14
# Setting MAX_IMAGE_PIXELS to None lifts Pillow's pixel-count limit (its
# decompression-bomb guard) so that very large images can be opened.
# NOTE(review): this is a process-wide setting; only feed trusted images.
Image.MAX_IMAGE_PIXELS = None
15
16
17
def get_x_and_y(name):
    """
    Reads the tile indices encoded at the end of a tile file name.

    The extension is dropped and the last two '_'-separated fields of the
    remaining name are converted to integers.

    Parameters
    ----------
    name : str
        Tile file name, e.g. 'img_3_4.png'.

    Returns
    -------
    x, y : tuple of int
        The two trailing fields, in the order they appear in the name.

    Raises
    ------
    ValueError
        If the name has fewer than two '_'-separated fields or the fields
        are not integers.
    """
    stem = os.path.splitext(name)[0]
    x_field, y_field = stem.rsplit('_', 2)[-2:]
    return int(x_field), int(y_field)
20
21
22
def split_image(img, x_tiles_cnt=None, y_tiles_cnt=None, x_tile_size=None, y_tile_size=None, base='img'):
    """
    Splits an image array to smaller tiles for further segmentation.
    Specify tiles count OR tiles size.
    X axis means the arr.shape[1] coordinate, be careful!
    Tile names are used to restore the initial image after segmentation.

    When tile sizes are given they take precedence over tile counts. A size is
    only used to derive the number of tiles (image extent // size, at least 1);
    the image is then divided evenly, so tiles may be somewhat larger than the
    requested size.

    Parameters
    ----------
    img : numpy ndarray
        The input image.
    x_tiles_cnt : integer
        Number of tiles along the x axis of img.
    y_tiles_cnt : integer
        Number of tiles along the y axis of img.
    x_tile_size : integer
        Size of tile along x axis.
    y_tile_size : integer
        Size of tile along y axis.
    base : str
        Base for tile names.

    Returns
    -------
    tiles : list
        List of tiles.
    tile_names : list
        List of tile names, '<base>_<x index>_<y index>'.

    Raises
    ------
    ValueError
        If neither both counts nor both sizes are given, or a size is not
        positive.
    """
    if (x_tile_size is not None) and (y_tile_size is not None):
        if x_tile_size <= 0 or y_tile_size <= 0:
            raise ValueError('Tile sizes must be positive.')
        # An image smaller than the requested tile must still yield one tile;
        # plain floor division would give zero tiles and an empty result.
        x_tiles_cnt = max(1, img.shape[1] // x_tile_size)
        y_tiles_cnt = max(1, img.shape[0] // y_tile_size)

    if (x_tiles_cnt is None) or (y_tiles_cnt is None):
        raise ValueError('Specify tiles count OR tiles size.')

    # Evenly spaced integer cut positions, including both image borders.
    x_ticks = np.linspace(0, img.shape[1], x_tiles_cnt + 1).astype(int)
    y_ticks = np.linspace(0, img.shape[0], y_tiles_cnt + 1).astype(int)

    tiles = []
    tile_names = []

    for x_num, (x_start, x_stop) in enumerate(zip(x_ticks[:-1], x_ticks[1:])):
        for y_num, (y_start, y_stop) in enumerate(zip(y_ticks[:-1], y_ticks[1:])):
            tiles.append(img[y_start:y_stop, x_start:x_stop])
            tile_names.append(f'{base}_{x_num}_{y_num}')
    return tiles, tile_names
70
71
72
def prepare_test_data(tiles, tile_names, base_dir, force=False):
    """
    Saves data in the proper way.

    Every tile is written to '<base_dir>/<name>/images/<name>.png'. Tiles whose
    maximum value is <= 1 are multiplied by 255 and converted to uint8 before
    being written.

    Parameters
    ----------
    tiles : list
        List of tiles.
    tile_names : list
        List of tile names.
    base_dir : str
        Full path to base directory, 'full/path/../data_test' in normal case.
    force : bool
        Rewrite existing files

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If base_dir is not empty and force is False.
    OSError
        If a tile image could not be written.
    """
    base_dir = Path(base_dir)
    base_dir.mkdir(parents=True, exist_ok=True)

    # List the directory once; the result drives both the guard and the wipe.
    if os.listdir(base_dir):
        if not force:
            raise ValueError(f'base_dir {base_dir} is not empty, use force=True option if you want to rewrite files')
        dir_util.remove_tree(str(base_dir))
        os.makedirs(base_dir)

    for tile, name in zip(tiles, tile_names):
        if tile.max() <= 1:
            tile = (tile * 255).astype(np.uint8)
        tile_dir = base_dir / name
        images_dir = tile_dir / 'images'
        tile_dir.mkdir()
        images_dir.mkdir()
        target = images_dir / f'{name}.png'
        # cv.imwrite reports failure through its boolean result rather than an
        # exception, so an ignored result would silently lose the tile.
        if not cv.imwrite(str(target), tile):
            raise OSError(f'Failed to write tile image {target}')
108
109
110
def restore_image(work_dir, tiff=False):
    """
    Restores the initial image.

    Every file in work_dir is treated as a tile whose name ends with
    '_<x>_<y>' (before the extension); tiles are stitched back together
    according to these indices.

    Parameters
    ----------
    work_dir : str
        Full path to directory with files.
    tiff : bool
        Is the target image a multilayer tiff or not. When True, the positive
        values of each tile are shifted by the largest value used by the
        previous tiles (tiles are visited row by row), so that values coming
        from different tiles do not collide; non-positive values become 0.

    Returns
    -------
    img : numpy ndarray
        Initial image

    Raises
    ------
    ValueError
        If work_dir contains no files or a file cannot be read as an image.
    KeyError
        If a tile of the grid is missing.
    """
    work_dir = Path(work_dir)

    # Row-major order (by y, then x) fixes the order in which offsets are assigned.
    file_names = sorted(os.listdir(work_dir), key=lambda f: get_x_and_y(f)[::-1])
    if not file_names:
        raise ValueError(f'No tiles found in {work_dir}')
    x_max, y_max = np.array([get_x_and_y(f) for f in file_names]).max(axis=0)

    tiles = {}
    label_offset = 0
    for file_name in file_names:
        tile = cv.imread(str(work_dir / file_name), -1)
        if tile is None:
            # cv.imread returns None instead of raising when it cannot decode a file.
            raise ValueError(f'Cannot read tile image {work_dir / file_name}')
        if tiff:
            tile_max = max(int(tile.max()), 0)
            shifted_max = label_offset + tile_max
            # Widen the dtype only when the shifted values would not fit, so the
            # addition cannot wrap around (e.g. uint16 tiles, > 65535 labels).
            wide = np.promote_types(tile.dtype, np.min_scalar_type(shifted_max))
            positive = tile > 0
            tile = (tile.astype(wide, copy=False) + label_offset) * positive
            # Accumulate instead of re-reading the tile maximum: a tile without
            # any positive value must not reset the offset back to zero.
            label_offset = shifted_max
        tiles[get_x_and_y(file_name)] = tile

    rows = [
        np.hstack([tiles[(x, y)] for x in range(x_max + 1)])
        for y in range(y_max + 1)
    ]
    return np.vstack(rows)
155
156
157
def _remove_tree_if_present(path):
    """Best-effort removal of a directory tree; a missing directory is not an error."""
    try:
        dir_util.remove_tree(str(path))
    except OSError:
        # Usually the directory simply does not exist yet; nothing to clean up.
        pass


def perform_segmentation(full_img_path, sample_dir, network_dir, force=False, features=None, tile_size=1000):
    """
    Runs the segmentation network on a full image.

    The image is split into tiles, the tiles are saved to sample_dir and copied
    to '<network_dir>/data_test', 'predict_test.sh' is executed with bash inside
    network_dir, and '<network_dir>/predictions' is copied to
    '<sample_dir>_segmented'.

    Parameters
    ----------
    full_img_path : str
        Path to the image to segment.
    sample_dir : str
        Directory where the tiles are saved (see prepare_test_data).
    network_dir : str
        Directory that contains 'predict_test.sh'.
    force : bool
        Passed to prepare_test_data: rewrite existing files in sample_dir.
    features : object, optional
        If not None, passed to NucleiFeatures and the resulting table is
        written to '<sample_dir>_segmented/<sample_dir name>.csv'.
    tile_size : integer
        Tile size passed to split_image for both axes.

    Returns
    -------
    None
    """
    network_dir = Path(network_dir)
    full_img_path = str(full_img_path)

    try:
        full_img = cv.imread(full_img_path, -1)
    except cv.error:
        full_img = None
    if full_img is None:
        # cv.imread returns None (rather than raising) when it cannot read the
        # file, so the matplotlib fallback has to cover that case as well.
        full_img = plt.imread(full_img_path)

    tiles, tile_names = split_image(img=full_img, x_tile_size=tile_size, y_tile_size=tile_size)
    prepare_test_data(tiles, tile_names, sample_dir, force=force)

    data_test_dir = network_dir / 'data_test'
    _remove_tree_if_present(data_test_dir)
    os.mkdir(str(data_test_dir))
    dir_util.copy_tree(str(sample_dir), str(data_test_dir))

    # Drop whatever is left in these directories from a previous run.
    _remove_tree_if_present(network_dir / 'predictions')
    _remove_tree_if_present(network_dir / 'albu/results_test')

    result_dir = str(Path(sample_dir)) + '_segmented'

    # Argument list + cwd instead of a shell string: paths with spaces or shell
    # metacharacters in network_dir can no longer break or alter the command.
    subprocess.run(['bash', 'predict_test.sh'], cwd=str(network_dir))
    dir_util.copy_tree(str(network_dir / 'predictions'), result_dir)

    if features is not None:
        # Path(...).name ignores a trailing separator, unlike os.path.split(...)[1],
        # which would produce an empty file name for 'some/dir/'.
        csv_path = f'{result_dir}/{Path(sample_dir).name}.csv'
        NucleiFeatures(f'{result_dir}/lgbm_test_sub2', sample_dir,
                       features=features).df().to_csv(csv_path, index=False)
192
193
194
def color_tiff(img, n=60):
    """
    Renders an integer image as an RGB image with random colours.

    Values are folded onto a palette of n random colours, so values that differ
    by a multiple of n share a colour. Zero is left black.

    Parameters
    ----------
    img : numpy ndarray
        Integer image; the value 0 is not coloured.
    n : integer
        Number of colours in the palette.

    Returns
    -------
    seg_color : numpy ndarray
        uint8 array of shape (*img.shape, 3).
    """
    folded = img % n
    # A non-zero value that is an exact multiple of n folds to 0 and would be
    # drawn as background; give it the extra palette slot n instead.
    folded[(folded == 0) & (img != 0)] = n

    # Slot 0 stays black; slots 1..n hold the random colours.
    palette = np.zeros((n + 1, 3), dtype=np.uint8)
    palette[1:] = np.random.randint(0, 256, (n, 3), dtype=np.uint8)

    # One lookup colours every pixel, replacing a per-value masking loop.
    return palette[folded]