Switch to unified view

a b/slideslicer/sample_from_slide.py
1
# coding: utf-8
2
3
from PIL import Image
4
import numpy as np
5
from collections import Counter
6
import pandas as pd
7
import os
8
import re
9
import json
10
from warnings import warn
11
12
import openslide
13
import cv2
14
from pycocotools.mask import encode, decode
15
16
from slideslicer.extract_rois_svs_xml import extract_rois_svs_xml
17
from slideslicer.slideutils import (plot_contour, get_median_color, 
18
                        get_thumbnail_magnification,
19
                        get_img_bbox, get_rotated_highres_roi,
20
                        get_uniform_tiles, 
21
                        get_threshold_tissue_mask, 
22
                        convert_contour2mask,
23
                        convert_mask2contour,
24
                        CropRotateRoi,
25
                        get_contour_centre, read_roi_patches_from_slide,
26
                        clip_roi_wi_bbox, sample_points)
27
28
29
def get_img_id(svsname):
30
    imgid = re.sub("\.svs$","", 
31
                   os.path.basename(svsname)
32
                   ).replace(" ", "_").replace("-","_")
33
    return imgid
34
35
def get_prefix(imgid, pos, name, tissueid, id, parentdir = "data", suffix=''):
36
    prefix = "{parentdir}/{typ}/{imgid}-{pos}-t{tissue}-r{roiid}-{typ}{suffix}".format(**{
37
                                        "tissue":tissueid,
38
                                        "pos": "x{}-y{}".format(*pos),
39
                                        "parentdir":parentdir,
40
                                        "imgid":imgid,
41
                                        "roiid":id,
42
                                        "typ": (name.replace(" ","_")),
43
                                        "suffix":suffix,
44
                                        })
45
    return prefix
46
47
48
def summarize_rois_wi_patch(rois, bg_names = ["tissue"], frac_thr=16):
49
    names = []
50
    areas = []
51
    ids = []
52
    
53
    tissue_info = []
54
    for rr in rois:
55
        if rr['name'] in bg_names:
56
            tissue_info.append(rr)
57
            continue
58
        names.append(rr['name'])
59
        areas.append(rr['area'])
60
        ids.append(rr['id'])
61
#     assert (len(tissue_info)==1)
62
    tissue_id = "+".join(sorted(["%s"%tt['id'] for tt in tissue_info]))
63
    dfareas = (pd.DataFrame(dict(area=areas, name=names, id=ids))
64
                     .sort_values("area", ascending=False)
65
               )
66
    areasum = (dfareas.groupby('name')
67
                     .agg({"area":sum, "id": "first"})
68
                     .sort_values("area", ascending=False)
69
              )
70
    if len(areasum) == 0:
71
        return {'name':'blank', 
72
            "id": tissue_id,
73
            "tissue_id": tissue_id,
74
            "stats": dfareas.to_dict(orient='records')
75
            }
76
    elif len(areasum)==1:
77
        name = areasum.index[0]
78
        id = areasum["id"][0]
79
    elif areasum["area"][0]/areasum["area"][1] > frac_thr:
80
        name = areasum.index[0]
81
        id = areasum["id"][0]
82
    else:
83
        name = '+'.join(areasum.index.tolist())
84
        id = '+'.join(sorted(areasum["id"].astype(str).tolist()))
85
    return {"name":name,
86
            "id": str(id),
87
            "tissue_id": tissue_id,
88
            "stats": dfareas.to_dict(orient='records')}
89
90
91
# Rewrite for generator if needed:
92
def visualise_chunks_and_rois(img_arr, roi_cropped_list,
93
                              nrows = 5, figsize=(15,15)
94
                             ):
95
    fig, axs = plt.subplots(nrows,len(img_arr)//nrows, figsize=figsize)
96
    for ax, reg, rois in zip(axs.ravel(), img_arr, roi_cropped_list):
97
        ax.imshow(reg)
98
        for rr in rois:
99
            if rr['name'] == 'tissue':
100
                continue
101
            plot_contour(rr["vertices"], ax=ax)
102
        xlab = "\n".join(["{}: {}".format(rr['id'], rr['name']) \
103
                          for rr in rois if rr['name'] !='tissue'])
104
        ax.set_xlabel(xlab)
105
        ax.set_xticklabels([])
106
        ax.set_yticklabels([])
107
        
108
109
def get_tissue_rois(slide,
110
                    roilist,
111
                    vis = False,
112
                    step = 1024,
113
                    magnlevel = 0,
114
                    target_size = None,
115
                    maxarea = 1e7,
116
                    random=False,
117
                    normal_only=True,
118
                    shift_factor = 2, 
119
                   ):
120
121
    print("NORMAL_ONLY", normal_only)
122
    if target_size is None:
123
        target_size = [step]*2
124
125
    tissue_rois = [roi for roi in roilist if roi['name']=='tissue']
126
127
    for roi in tissue_rois:
128
        print("tissue roi, id", roi["id"])
129
        cont = roi["vertices"]
130
        points = sample_points(cont,
131
                              spacing = step,
132
                              shift = -step//shift_factor,
133
                              mode = 'random' if random else 'grid')
134
135
        print("roi {} #{}:\t{:d} points sampled".format(roi["name"], roi["id"],len(points), ))
136
        pointroilist = [{"vertices":[pp], "area":0} for pp in points]
137
        
138
#         img_arr, roi_cropped_list, msk_arr, = \
139
        imgroiiter = read_roi_patches_from_slide(slide, 
140
                                        pointroilist,
141
                                        but_list = roilist,
142
                                        target_size = target_size,
143
                                        magnlevel = magnlevel,
144
                                        maxarea = maxarea,
145
                                        color=1,
146
                                        nchannels=3,
147
                                        allcomponents = True,
148
                                        nomask=True,
149
                                       )
150
#         if vis:
151
#             plt.scatter(points[:,0], points[:,1],c='r')
152
#             plot_contour(cont)
153
        # filter for rois with only normal tissue 
154
        def filter_(x):
155
            return all(roi['name']=='tissue' for roi in x[1])
156
        if normal_only:
157
            imgroiiter = filter(filter_, imgroiiter)
158
        yield imgroiiter
159
160
161
def save_tissue_chunks(imgroiiter, imgid, parentdir="data",
162
                       lower = [0, 0, 180],
163
                       upper = [179, 10, 255],
164
                       close=50,
165
                       open_=30,
166
                       filtersize = 20,
167
                       frac_thr=16,
168
                       ):
169
    for ii, (reg, rois, _, start_xy) in enumerate(imgroiiter):
170
        sumdict = summarize_rois_wi_patch(rois, bg_names = [], frac_thr=frac_thr)
171
        prefix = get_prefix(imgid, start_xy, sumdict["name"], sumdict["id"], ii,
172
                            parentdir=parentdir,)
173
174
        #fn_summary_json = prefix + "-summary.json"
175
        fn_json = prefix + ".json"
176
        fnoutpng = prefix + '.png'
177
        print(fnoutpng)
178
179
        os.makedirs(os.path.dirname(fn_json), exist_ok=True)
180
        #with open(fn_summary_json, 'w+') as fhj: json.dump(sumdict, fhj)
181
        if isinstance(reg, Image.Image):
182
            reg.save(fnoutpng)
183
        else:
184
            Image.fromarray(reg).save(fnoutpng)
185
186
        rois = add_roi_bytes(rois, np.asarray(reg),
187
                lower=lower, upper=upper,
188
                open=open_, close=close,
189
                filtersize=filtersize)
190
        with open(fn_json, 'w+') as fhj: json.dump(rois, fhj)
191
192
193
def add_roi_bytes(rois, reg,
194
                  lower = [0, 0, 180],
195
                  upper = [179, 25, 255],
196
                  filtersize=25,
197
                  close=True,
198
                  open=False,
199
                  minlen = -1):
200
    if minlen==-1:
201
        minlen=filtersize
202
    rois = rois.copy()
203
    tissue_roi = None
204
    other_mask_ = 0
205
    
206
    print('ROIS:', *[roi_['name'] for roi_ in rois])
207
    for roi_ in rois:
208
        if roi_["name"] == "tissue":
209
            tissue_roi = roi_
210
            continue
211
        mask_ = convert_contour2mask(roi_["vertices"], 
212
                                     reg.shape[1], reg.shape[0],
213
                                     fill=1, order='F')
214
215
        cocomask = encode(np.asarray(mask_, dtype='uint8'))
216
        cocomask["counts"] = cocomask["counts"].decode('utf-8')
217
        roi_.update(cocomask)
218
        if isinstance(roi_["vertices"], np.ndarray):
219
            roi_["vertices"] = roi_["vertices"].tolist()
220
        other_mask_ = np.maximum(other_mask_, mask_)
221
    
222
    roi_ = tissue_roi
223
    if roi_ is None:
224
        warn("Someting strange is going on. Make sure no tissue chunks are missing")
225
        roi_ = {'vertices': []}
226
    #print('tissue roi', roi_)
227
    if reg is not None:
228
        mask_ = get_threshold_tissue_mask(reg, color=True,
229
                                filtersize=filtersize,
230
                                dtype=bool,
231
                                open=open, close=close,
232
                                lower = lower, upper = upper)
233
        if mask_.sum()==0:
234
            roi_["vertices"]= []
235
            print("skipping empty mask", roi_['name'], roi_['id'])
236
        verts = convert_mask2contour(mask_.astype('uint8'), minlen=minlen)
237
        # print("verts", len(verts))
238
        if len(verts)>0:
239
            #print('vertices', verts[np.argmax(map(len,verts))])
240
            roi_["vertices"] = verts[np.argmax(map(len,verts))]
241
        else:
242
            #print("verts", len(verts), roi_["vertices"])
243
            pass
244
        mask_ = np.asarray(mask_, order='F')
245
    else:
246
        mask_ = convert_contour2mask(roi_["vertices"], reg.shape[1], reg.shape[0], 
247
                             fill=1, order='F')
248
        if mask_.sum()==0:
249
            roi_["vertices"]= []
250
            #continue
251
252
    if isinstance(other_mask_, np.ndarray):
253
        mask_ = mask_.astype(bool) & ~other_mask_.astype(bool)
254
    cocomask = encode(np.asarray(mask_, dtype='uint8'))
255
    cocomask["counts"] = cocomask["counts"].decode('utf-8')
256
    roi_.update(cocomask)
257
    if isinstance(roi_["vertices"], np.ndarray):
258
        roi_["vertices"] = roi_["vertices"].tolist()   
259
    rois = [rr for rr in rois if len(rr['vertices'])>0]
260
    return rois
261
262
263
if __name__ == '__main__':
264
    import sys
265
    import argparse
266
267
    parser = argparse.ArgumentParser()
268
    parser.add_argument(
269
      '--data-root',
270
      type=str,
271
      default='../data',
272
      help='The directory where the input data will be stored.')
273
274
    parser.add_argument(
275
      '--json-dir',
276
      type=str,
277
      default='../data/roi-json',
278
      help='The directory where the roi JSON files will be stored.')
279
280
    parser.add_argument(
281
      '--keep-empty',
282
      action='store_true',
283
      default=False,
284
      help='keep empty tissue chunks (with no annotations within)')
285
286
    parser.add_argument(
287
      '--target-side',
288
      type=int,
289
      default=1024,
290
      help='The directory where the input data will be stored.')
291
292
    parser.add_argument(
293
      '--max-area',
294
      type=float,
295
      default=1e7,
296
      help='maximal area of a roi')
297
298
    parser.add_argument(
299
      '--fnxml',
300
      dest='fnxml',
301
      type=str,
302
      help='The XML files for ROI.')
303
304
    parser.add_argument(
305
      '--all-grid',
306
      action='store_true',
307
      default=False,
308
      help='store all grid patches (by defaut grid patches that overlap features will be removed)')
309
310
    parser.add_argument(
311
      '--target-sampling',
312
      action='store_true',
313
      default=False,
314
      help='store only grid patches')
315
316
    parser.add_argument(
317
      '--keep-levels',
318
      type=int,
319
      default=3,
320
      help='.')
321
322
    parser.add_argument(
323
      '--magnlevel',
324
      type=int,
325
      default=0,
326
      help='.')
327
328
    parser.add_argument(
329
      '--frac-stride',
330
      type=int,
331
      default=1,
332
      help='.')
333
334
    prms = parser.parse_args()
335
    VISUALIZE = False
336
337
    lower = [0, 0, 180]
338
    upper = [179, 10, 255]
339
    close=50
340
    open_=30
341
    filtersize = 20
342
343
    fnsvs = re.sub(".xml$", ".svs", prms.fnxml)
344
345
    outdir = os.path.join(prms.data_root, "data_{}/fullsplit".format(prms.target_side))
346
347
    ## setup
348
    imgid = get_img_id(fnsvs)
349
350
    target_size = [prms.target_side, prms.target_side,]
351
    #os.makedirs(outdir)
352
353
    # ## Read XML ROI, convert, and save as JSON
354
    fnjson = extract_rois_svs_xml(prms.fnxml, outdir=prms.json_dir,
355
                                  remove_empty = ~prms.keep_empty,
356
                                  keeplevels=prms.keep_levels)
357
358
    with open(fnjson,'r') as fh:
359
        roilist = json.load(fh)
360
361
    print("ROI type counts")
362
    print(pd.Series([roi["name"] for roi in roilist]).value_counts())
363
364
    # read slide
365
    slide = openslide.OpenSlide(fnsvs)
366
367
    # load the thumbnail image
368
    img = np.asarray(slide.associated_images["thumbnail"])
369
370
    median_color = get_median_color(slide)
371
    ratio = get_thumbnail_magnification(slide)
372
373
    print("full scale slide dimensions: w={}, h={}".format(*slide.dimensions))
374
375
    if VISUALIZE:
376
        from matplotlib import pyplot as plt
377
        colordict = {'open glom': 'b',
378
                     'scler glom': 'm',
379
                     'infl':'r',
380
                     'tissue':'w',
381
                     'other tissue':'y',
382
                     'art':'olive',
383
                     'fold':'y'}
384
385
        #cell#
386
387
        plt.figure(figsize = (18,10))
388
        plt.imshow(img)
389
        for roi in roilist:
390
            plot_contour(roi["vertices"]/ratio, c=colordict[roi['name']])
391
392
        #cell#
393
        vert = roilist[19]["vertices"]
394
        target_size = [1024]*2
395
        x,y,w,h = cv2.boundingRect(np.asarray(vert).round().astype(int))
396
        mask, cropped_vertices = get_region_mask(vert, [x,y], (w,h), color=(255,))
397
398
        plt.imshow(mask)
399
        plot_contour(cropped_vertices, c='r')
400
        print(mask.max())
401
402
    #############################
403
    if prms.target_sampling:
404
        print("READING TARGETED ROIS", file=sys.stderr)
405
406
        imgroiiter = read_roi_patches_from_slide(slide, roilist,
407
                                target_size = target_size,
408
                                maxarea = prms.max_area,
409
                                nchannels=3,
410
                                allcomponents=True,
411
                               )
412
413
        print("READING AND SAVING SMALLER ROIS (GLOMERULI, INFLAMMATION LOCI ETC.)",
414
              file=sys.stderr) 
415
416
        for reg, rois,_, start_xy in imgroiiter:
417
            sumdict = summarize_rois_wi_patch(rois, bg_names = ["tissue"], frac_thr=16)
418
            prefix = get_prefix(imgid, start_xy, sumdict["name"], sumdict["tissue_id"],
419
                                sumdict["id"], parentdir=outdir, suffix='-targeted')
420
            #fn_summary_json = prefix + "-summary.json"
421
            fn_json = prefix + ".json"
422
            fnoutpng = prefix + '.png'
423
            print(fnoutpng)
424
            os.makedirs(os.path.dirname(fn_json), exist_ok=True)
425
            
426
            #with open(fn_summary_json, 'w+') as fhj: json.dump(sumdict, fhj)
427
            if isinstance(reg, Image.Image):
428
                reg.save(fnoutpng)
429
            else:
430
                Image.fromarray(reg).save(fnoutpng)
431
            
432
            rois = add_roi_bytes(rois, reg, lower=lower, upper=upper,
433
                                 close=close,
434
                                 open=open_,
435
                                 filtersize = filtersize)
436
            with open(fn_json, 'w+') as fhj: json.dump( rois, fhj)
437
438
    print("READING AND SAVING _FEATURELESS_ / NORMAL TISSUE", file=sys.stderr)
439
440
    magnification = 4**prms.magnlevel
441
    real_side = prms.target_side * magnification
442
443
    for tissue_chunk_iter in get_tissue_rois(slide,
444
                                            roilist,
445
                                            vis = False,
446
                                            step = real_side // prms.frac_stride,
447
                                            target_size = [real_side]*2,
448
                                            maxarea = 1e7,
449
                                            random=False,
450
                                            normal_only = not prms.all_grid,
451
                                           ):
452
            # save
453
            print('saving tissue chunk')
454
            save_tissue_chunks(tissue_chunk_iter, imgid, parentdir=outdir,
455
                               close=close,
456
                               open_=open_,
457
                               frac_thr=16,
458
                               filtersize = filtersize)