|
a |
|
b/fetal_net/postprocess.py |
|
|
1 |
from scipy.ndimage.measurements import label |
|
|
2 |
from scipy.ndimage.filters import gaussian_filter |
|
|
3 |
from scipy.ndimage.morphology import binary_fill_holes |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def get_main_connected_component(data): |
|
|
8 |
labeled_array, num_features = label(data) |
|
|
9 |
i = np.argmax([np.sum(labeled_array == _) for _ in range(1, num_features + 1)]) + 1 |
|
|
10 |
return labeled_array == i |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def postprocess_prediction(pred, gaussian_std=1, threshold=0.5, fill_holes=True, connected_component=True): |
|
|
14 |
pred = gaussian_filter(pred, gaussian_std) > threshold |
|
|
15 |
if fill_holes: |
|
|
16 |
pred = binary_fill_holes(pred) |
|
|
17 |
if connected_component: |
|
|
18 |
pred = get_main_connected_component(pred) |
|
|
19 |
return pred |