[7823dd]: / pathaia / datasets / functional_api.py

Download this file

783 lines (603 with data), 21.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
# coding: utf8
"""
A module to implement useful functions to apply to dataset.
I still don't know exactly what we are putting into this module.
"""
import math
from collections import Counter
from functools import wraps
from typing import (
    Sequence, Dict, Any, Callable, Generator, Union, Iterable
)

import numpy as np
import openslide

from ..util.types import RefDataSet, SplitDataSet, DataSet
from .data import fast_slide_query
from .errors import (
    InvalidDatasetError,
    InvalidSplitError,
    TagNotFoundError
)
def extend_to_split_datasets(processing: Callable) -> Callable:
    """
    Decorate a dataset processing to extend usage to split datasets.

    Args:
        processing: a function that takes a RefDataSet and returns a RefDataSet.

    Returns:
        Function adapted to DataSet inputs (plain tuple or split dict).

    """
    # wraps keeps the processing's name/docstring on the wrapper
    @wraps(processing)
    def extended_version(
        dataset: DataSet, *args, **kwargs
    ) -> Union[Dict, DataSet]:
        """
        Apply the processing to a plain dataset or to every split of a dict.

        Args:
            dataset: a (x, y) tuple or a {split_name: (x, y)} dict.

        Returns:
            Result of the processing, mapped over each split when needed.

        Raises:
            InvalidDatasetError: when dataset is neither a tuple nor a dict.

        """
        if isinstance(dataset, tuple):
            return processing(dataset, *args, **kwargs)
        if isinstance(dataset, dict):
            # apply the processing independently to every split
            return {
                set_name: processing(set_data, *args, **kwargs)
                for set_name, set_data in dataset.items()
            }
        raise InvalidDatasetError(
            "{} is not a valid type for datasets!"
            " It should be a {}...".format(type(dataset), DataSet)
        )
    return extended_version
@extend_to_split_datasets
def info(dataset: RefDataSet) -> Dict:
    """
    Produce population info on an unsplit dataset.

    Args:
        dataset: (x, y) samples of a dataset.

    Returns:
        Unique labels in the dataset with their population counts.

    """
    # only labels matter here; samples are ignored
    _, y = dataset
    # Counter does the per-label population count in one pass
    return dict(Counter(y))
@extend_to_split_datasets
def ratio_info(dataset: RefDataSet) -> Dict:
    """
    Produce label-ratio info on an unsplit dataset.

    Args:
        dataset: (x, y) samples of a dataset.

    Returns:
        Unique labels in the dataset with their population ratios.

    """
    _, y = dataset
    total = len(y)
    # Counter does the population count; ratios derive from it
    # (empty dataset yields an empty dict, so no division by zero)
    return {tag: count / total for tag, count in Counter(y).items()}
@extend_to_split_datasets
def class_data(dataset: RefDataSet, class_name: Union[str, int]) -> RefDataSet:
    """
    Extract the samples of a single class from an unsplit dataset.

    Args:
        dataset: (x, y) samples of a dataset.
        class_name: label of the class to extract.

    Returns:
        (x, y) restricted to the samples labelled class_name.

    Raises:
        TagNotFoundError: when class_name does not appear in the labels.

    """
    x, y = dataset
    # guard clause: fail fast on an unknown label
    if class_name not in y:
        raise TagNotFoundError(
            "Tag '{}' is not in dataset {}!".format(
                class_name, info(dataset)
            )
        )
    res_x = [spl for spl, tag in zip(x, y) if tag == class_name]
    res_y = [tag for tag in y if tag == class_name]
    return res_x, res_y
@extend_to_split_datasets
def shuffle_dataset(dataset: RefDataSet) -> RefDataSet:
    """
    Shuffle samples in a dataset.

    Args:
        dataset: (x, y) samples of a dataset.

    Returns:
        The same samples in a random order (x/y pairing preserved).

    """
    samples, labels = dataset
    # draw a single random order and apply it to both sequences
    order = np.arange(len(labels))
    np.random.shuffle(order)
    shuffled_samples = [samples[i] for i in order]
    shuffled_labels = [labels[i] for i in order]
    return shuffled_samples, shuffled_labels
@extend_to_split_datasets
def clean_dataset(
    dataset: RefDataSet, dtype: type, rm: Sequence[Any]
) -> RefDataSet:
    """
    Remove bad data from a reference dataset.

    Args:
        dataset: (x, y) samples of a dataset.
        dtype: type labels must have to be kept.
        rm: sequence of labels to remove from the dataset.

    Returns:
        Purified dataset.

    """
    x, y = dataset
    # a label survives when it has the right type and is not blacklisted
    keep = [isinstance(tag, dtype) and tag not in rm for tag in y]
    pure_x = [spl for spl, ok in zip(x, keep) if ok]
    pure_y = [tag for tag, ok in zip(y, keep) if ok]
    return pure_x, pure_y
def balance_cat(dataset: RefDataSet, cat: Any, lack: int) -> RefDataSet:
    """
    Compensate lack of a category in a dataset by random sample duplication.

    Args:
        dataset: (x, y) samples of a dataset.
        cat: label in the dataset to enrich.
        lack: number of missing samples to reach the expected population.

    Returns:
        Padding samples and labels for the category.

    """
    samples, labels = dataset
    cat_samples = [spl for spl, lab in zip(samples, labels) if lab == cat]
    # shuffle once, then cycle through the category until the lack is filled
    order = np.arange(len(cat_samples))
    np.random.shuffle(order)
    x_padding = [cat_samples[order[k % len(order)]] for k in range(lack)]
    y_padding = [cat] * lack
    return x_padding, y_padding
@extend_to_split_datasets
def balance_dataset(dataset: RefDataSet) -> RefDataSet:
    """
    Balance the dataset using the balance_cat function on each category.

    Args:
        dataset: (x, y) samples of a dataset.

    Returns:
        The balanced dataset.

    Raises:
        InvalidDatasetError: when the dataset holds no label at all.

    """
    balanced_x = list(dataset[0])
    balanced_y = list(dataset[1])
    cat_count = info(dataset)
    try:
        # every category is padded up to the majority category's size
        maj_count = max(cat_count.values())
        for cat, count in cat_count.items():
            lack = maj_count - count
            if lack > 0:
                x_pad, y_pad = balance_cat(dataset, cat, lack)
                balanced_x += x_pad
                balanced_y += y_pad
        return balanced_x, balanced_y
    except ValueError as e:
        # max() on an empty collection means the dataset has no labels
        raise InvalidDatasetError(
            "{} check your dataset: {}".format(e, cat_count)
        )
@extend_to_split_datasets
def fair_dataset(
    dataset: RefDataSet, dtype: type, rm: Sequence[Any]
) -> RefDataSet:
    """
    Make a dataset fair: purify, balance and shuffle it.

    Args:
        dataset: (x, y) samples of a dataset.
        dtype: type labels must have to be kept.
        rm: sequence of labels to remove from the dataset.

    Returns:
        Fair dataset.

    """
    cleaned = clean_dataset(dataset, dtype, rm)
    balanced = balance_dataset(cleaned)
    return shuffle_dataset(balanced)
@extend_to_split_datasets
def clip_dataset(dataset: RefDataSet, max_spl: int) -> RefDataSet:
    """
    Clip a dataset to a maximum number of samples.

    Args:
        dataset: (x, y) samples of a dataset.
        max_spl: max number of samples to keep.

    Returns:
        Clipped dataset.

    """
    samples, labels = dataset
    bound = min(max_spl, len(samples))
    return samples[:bound], labels[:bound]
def _split_sections(
    dataset: RefDataSet,
    named_ratios: Iterable,
    population: Dict,
    ratios: Dict
) -> SplitDataSet:
    """
    Build the split sets, preserving the per-class ratios of the dataset.

    Args:
        dataset: (x, y) samples of a dataset.
        named_ratios: iterable of (set_name, set_ratio) pairs.
        population: per-class sample counts of the dataset.
        ratios: per-class ratios (only its class keys are used here).

    Returns:
        Dict mapping each set name to its (x, y) split.

    """
    # per-class read offsets, advanced as splits consume samples
    offsets = {class_name: 0 for class_name in ratios}
    result = dict()
    for set_name, set_ratio in named_ratios:
        x_set = []
        y_set = []
        for class_name, offset in offsets.items():
            class_set_size = int(set_ratio * population[class_name])
            cx, cy = class_data(dataset, class_name)
            x_set += cx[offset:offset + class_set_size]
            y_set += cy[offset:offset + class_set_size]
            offsets[class_name] += class_set_size
        result[set_name] = (x_set, y_set)
    return result
def split_dataset(
    dataset: RefDataSet,
    sections: Sequence,
    preserve_ratio: bool = True
) -> SplitDataSet:
    """
    Compute splits of the dataset from ratios.

    Args:
        dataset: (x, y) samples of a dataset.
        sections: ratios of the different splits, should sum to 1;
            either a dict {set_name: ratio} or a list/tuple of ratios
            (splits are then keyed by their index).
        preserve_ratio: unused, kept for backward compatibility.

    Returns:
        Dict mapping each set name to its (x, y) split.

    Raises:
        InvalidSplitError: when sections has an unsupported type
            or when the provided ratios do not sum to 1.

    """
    ratios = ratio_info(dataset)
    population = info(dataset)
    if isinstance(sections, dict):
        # math.isclose instead of == 1: ratio sums like 0.7 + 0.2 + 0.1
        # are not always exactly 1.0 in floating point
        if math.isclose(sum(sections.values()), 1):
            return _split_sections(
                dataset, sections.items(), population, ratios
            )
        raise InvalidSplitError(
            "Split values provided do not sum to 1: {}".format(sections)
        )
    if isinstance(sections, (list, tuple)):
        if math.isclose(sum(sections), 1):
            return _split_sections(
                dataset, enumerate(sections), population, ratios
            )
        raise InvalidSplitError(
            "Split values provided do not sum to 1: {}".format(sections)
        )
    raise InvalidSplitError(
        "Invalid arguments provided to the split method: \n{}\n{}".format(
            sections, info(dataset)
        )
    )
# Decorators on dataset generators
# Careful here, since above functions are used as pre-processing steps,
# (called before the wrapped function)
# the calling order of the decorators is reversed:
# ---------
# @clean -|
# @balance -|-----> @be_fair
# @shuffle -|
# @clip
# @split
# @batch
# def my_generator(dataset):
# x, y = dataset
# for sx, sy in zip(x, y):
# yield sx, sy
# -------------------------
# will first shuffle, then clip the dataset...
def pre_shuffle(data_generator: Callable) -> Callable:
    """
    Decorate a data generator function with the shuffle function.

    Args:
        data_generator: a function that takes a dataset and yields samples.

    Returns:
        Shuffling version of the data generator.

    """
    # wraps keeps the generator's name/docstring on the wrapper
    @wraps(data_generator)
    def shuffled_version(dataset: DataSet) -> Iterable:
        """Shuffle the dataset, then delegate to the data generator."""
        return data_generator(shuffle_dataset(dataset))
    return shuffled_version
def pre_balance(data_generator: Callable) -> Callable:
    """
    Decorate a data generator function with the balance function.

    Args:
        data_generator: a function that takes a dataset and yields samples.

    Returns:
        Balancing version of the data generator.

    """
    # wraps keeps the generator's name/docstring on the wrapper
    @wraps(data_generator)
    def balanced_version(dataset: DataSet) -> Iterable:
        """Balance the dataset, then delegate to the data generator."""
        return data_generator(balance_dataset(dataset))
    return balanced_version
def pre_split(sections: Sequence) -> Callable:
    """Parameterize the decorator with the split sections."""
    def decorator(data_generator: Callable) -> Callable:
        """
        Decorate a data generator function with the split function.

        Args:
            data_generator: a function that takes a dataset and yields samples.

        Returns:
            Splitting version of the data generator.

        """
        @wraps(data_generator)
        def split_version(dataset: DataSet) -> Iterable:
            """Split the dataset, then run the generator on every split."""
            new_dataset = split_dataset(dataset, sections)
            # extend the generator so it maps over the split dict
            @extend_to_split_datasets
            def gen(ds):
                return data_generator(ds)
            return gen(new_dataset)
        return split_version
    return decorator
def pre_clip(max_spl: int) -> Callable:
    """Parameterize the decorator with the maximum number of samples."""
    def decorator(data_generator: Callable) -> Callable:
        """
        Decorate a data generator function with the clip function.

        Args:
            data_generator: a function that takes a dataset and yields samples.

        Returns:
            Clipping version of the data generator.

        """
        @wraps(data_generator)
        def clipped_version(dataset: DataSet) -> Iterable:
            """Clip the dataset, then delegate to the data generator."""
            return data_generator(clip_dataset(dataset, max_spl))
        return clipped_version
    return decorator
def pre_batch(batch_size: int, keep_last: bool = False) -> Callable:
    """Parameterize the decorator with the batch size."""
    def decorator(data_generator: Callable) -> Callable:
        """
        Decorate a data generator function with the batch function.

        Args:
            data_generator: a function that takes a dataset and yields samples.

        Returns:
            Batched version of the data generator.

        """
        @wraps(data_generator)
        def batched_version(dataset: DataSet) -> Iterable:
            """
            Yield (x_batch, y_batch) lists of batch_size samples.

            A final partial batch is yielded only when keep_last is True.
            """
            xb = []
            yb = []
            for x, y in data_generator(dataset):
                xb.append(x)
                yb.append(y)
                if len(xb) == batch_size:
                    yield xb, yb
                    # start fresh lists: the yielded ones may still be in use
                    xb = []
                    yb = []
            if xb and len(xb) < batch_size and keep_last:
                yield xb, yb
        return batched_version
    return decorator
def pre_clean(dtype: type, rm: Sequence[Any]) -> Callable:
    """Parameterize the decorator with the cleaning rules."""
    def decorator(data_generator: Callable) -> Callable:
        """
        Decorate a data generator function with the clean function.

        Args:
            data_generator: a function that takes a dataset and yields samples.

        Returns:
            Cleaning version of the data generator.

        """
        @wraps(data_generator)
        def cleaned_version(dataset: DataSet) -> Iterable:
            """Clean the dataset, then delegate to the data generator."""
            return data_generator(clean_dataset(dataset, dtype, rm))
        return cleaned_version
    return decorator
def pre_be_fair(dtype: type, rm: Sequence[Any]) -> Callable:
    """Parameterize the decorator with the fairness rules."""
    def decorator(data_generator: Callable) -> Callable:
        """
        Decorate a data generator function with the fair function.

        Args:
            data_generator: a function that takes a dataset and yields samples.

        Returns:
            Fair version of the data generator.

        """
        @wraps(data_generator)
        def fair_version(dataset: DataSet) -> Iterable:
            """Make the dataset fair, then delegate to the data generator."""
            return data_generator(fair_dataset(dataset, dtype, rm))
        return fair_version
    return decorator
def post_shuffle(dataset_creator: Callable) -> Callable:
    """
    Decorate a dataset creator function with the shuffle function.

    Args:
        dataset_creator: a function that takes any arguments and returns a dataset.

    Returns:
        Creator that shuffles the dataset after creation.

    """
    # wraps keeps the creator's name/docstring on the wrapper
    @wraps(dataset_creator)
    def shuffled_version(*args, **kwargs) -> RefDataSet:
        """Create the dataset, then return its shuffled version."""
        return shuffle_dataset(dataset_creator(*args, **kwargs))
    return shuffled_version
def post_balance(dataset_creator: Callable) -> Callable:
    """
    Decorate a dataset creator function with the balance function.

    Args:
        dataset_creator: a function that takes any arguments and returns a dataset.

    Returns:
        Creator that balances the dataset after creation.

    """
    # wraps keeps the creator's name/docstring on the wrapper
    @wraps(dataset_creator)
    def balanced_version(*args, **kwargs) -> RefDataSet:
        """Create the dataset, then return its balanced version."""
        return balance_dataset(dataset_creator(*args, **kwargs))
    return balanced_version
def post_split(sections: Sequence) -> Callable:
    """Parameterize the decorator with the split sections."""
    def decorator(dataset_creator: Callable) -> Callable:
        """
        Decorate a dataset creator function with the split function.

        Args:
            dataset_creator: a function that takes any arguments and returns a dataset.

        Returns:
            Creator that splits the dataset after creation.

        """
        @wraps(dataset_creator)
        def split_version(*args, **kwargs) -> SplitDataSet:
            """Create the dataset, then return its split version."""
            return split_dataset(dataset_creator(*args, **kwargs), sections)
        return split_version
    return decorator
def post_clip(max_spl: int) -> Callable:
    """Parameterize the decorator with the maximum number of samples."""
    def decorator(dataset_creator: Callable) -> Callable:
        """
        Decorate a dataset creator function with the clip function.

        Args:
            dataset_creator: a function that takes any arguments and returns a dataset.

        Returns:
            Creator that clips the dataset after creation.

        """
        @wraps(dataset_creator)
        def clipped_version(*args, **kwargs) -> RefDataSet:
            """Create the dataset, then return its clipped version."""
            return clip_dataset(dataset_creator(*args, **kwargs), max_spl)
        return clipped_version
    return decorator
def post_clean(dtype: type, rm: Sequence[Any]) -> Callable:
    """Parameterize the decorator with the cleaning rules."""
    def decorator(dataset_creator: Callable) -> Callable:
        """
        Decorate a dataset creator function with the clean function.

        Args:
            dataset_creator: a function that takes any arguments and returns a dataset.

        Returns:
            Creator that cleans the dataset after creation.

        """
        @wraps(dataset_creator)
        def cleaned_version(*args, **kwargs) -> RefDataSet:
            """Create the dataset, then return its cleaned version."""
            return clean_dataset(dataset_creator(*args, **kwargs), dtype, rm)
        return cleaned_version
    return decorator
def post_be_fair(dtype: type, rm: Sequence[Any]) -> Callable:
    """Parameterize the decorator with the fairness rules."""
    def decorator(dataset_creator: Callable) -> Callable:
        """
        Decorate a dataset creator function with the fair function.

        Args:
            dataset_creator: a function that takes any arguments and returns a dataset.

        Returns:
            Creator that makes the dataset fair after creation.

        """
        @wraps(dataset_creator)
        def fair_version(*args, **kwargs) -> RefDataSet:
            """Create the dataset, then return its fair version."""
            return fair_dataset(dataset_creator(*args, **kwargs), dtype, rm)
        return fair_version
    return decorator
def query_slide(
    slides: Dict[str, openslide.OpenSlide],
    patch_size: int
) -> Callable:
    """Parameterize the decorator with the slides and the patch size."""
    def decorator(data_generator: Callable) -> Callable:
        """
        Decorate a data generator so slide queries become pixel data.

        Args:
            data_generator: a function that takes a dataset and yields samples.

        Returns:
            Version of the data generator that queries the slides.

        """
        @wraps(data_generator)
        def query_version(dataset: DataSet) -> Generator:
            """Yield (slide patch, label) pairs instead of (query, label)."""
            for x, y in data_generator(dataset):
                yield fast_slide_query(slides, x, patch_size), y
        return query_version
    return decorator