a b/test/test_predict.py
1
import nibabel as nib
2
import numpy as np
3
4
from unittest import TestCase
5
6
from fetal_net.utils.patches import compute_patch_indices, get_patch_from_3d_data, reconstruct_from_patches
7
8
9
class TestPrediction(TestCase):
    """Round-trip tests for patch extraction and reconstruction.

    Each test cuts an array into patches with ``compute_patch_indices`` /
    ``get_patch_from_3d_data`` and checks that ``reconstruct_from_patches``
    rebuilds the expected array.
    """

    def setUp(self):
        """Create a 3D NIfTI image whose voxels hold their flat index."""
        image_shape = (120, 144, 90)
        # The largest value (~1.5e6) fits in int32. nibabel >= 5 rejects int64
        # image data unless a dtype is passed explicitly, so avoid int64 here.
        data = np.arange(np.prod(image_shape), dtype=np.int32).reshape(image_shape)
        affine = np.eye(4)
        self.image = nib.Nifti1Image(data, affine)

    def _image_data(self):
        """Return the array stored in ``self.image``.

        ``np.asanyarray(img.dataobj)`` is nibabel's replacement for the removed
        ``get_data()``. It keeps the stored dtype, unlike ``get_fdata()``.
        """
        return np.asanyarray(self.image.dataobj)

    def test_reconstruct_from_patches(self):
        """Non-overlapping patches reassemble into the original image."""
        patch_shape = (32, 32, 32)
        patch_overlap = 0
        image_data = self._image_data()
        patch_indices = compute_patch_indices(self.image.shape, patch_shape, patch_overlap)
        patches = [get_patch_from_3d_data(image_data, patch_shape, index) for index in patch_indices]
        reconstructed_data = reconstruct_from_patches(patches, patch_indices, self.image.shape)
        np.testing.assert_array_equal(reconstructed_data, image_data)

    def test_reconstruct_with_overlapping_patches(self):
        """Voxels covered by several patches are averaged.

        Every patch location is used twice: once with the original values and
        once with the values lowered by 2. The expected reconstruction is
        therefore the original data minus 1.
        """
        patch_overlap = 0
        patch_shape = (32, 32, 32)
        image_data = self._image_data()
        patch_indices = compute_patch_indices(self.image.shape, patch_shape, patch_overlap)
        patches = [get_patch_from_3d_data(image_data, patch_shape, index) for index in patch_indices]
        # Duplicate every patch with a copy whose values are 2 lower.
        patches.extend([patch - 2 for patch in patches])
        patch_indices = np.concatenate([patch_indices, patch_indices], axis=0)
        reconstructed_data = reconstruct_from_patches(patches, patch_indices, self.image.shape)
        np.testing.assert_array_equal(reconstructed_data, image_data - 1)

    def test_reconstruct_with_overlapping_patches2(self):
        """Mixing two overlap settings over the same data reproduces the data."""
        image_shape = (144, 144, 144)
        data = np.arange(np.prod(image_shape)).reshape(image_shape)
        patch_overlap = 16
        patch_shape = (64, 64, 64)
        patch_indices = compute_patch_indices(data.shape, patch_shape, patch_overlap)
        patches = [get_patch_from_3d_data(data, patch_shape, index) for index in patch_indices]

        # A second tiling with a different overlap (32), so voxels receive an
        # uneven number of contributions. All patches carry the same values.
        second_tiling_indices = compute_patch_indices(data.shape, patch_shape, 32)
        patch_indices = np.concatenate([patch_indices, second_tiling_indices])
        patches.extend([get_patch_from_3d_data(data, patch_shape, index) for index in second_tiling_indices])
        reconstructed_data = reconstruct_from_patches(patches, patch_indices, data.shape)
        np.testing.assert_array_equal(reconstructed_data, data)

    def test_reconstruct_with_multiple_channels(self):
        """A leading channel axis is kept in patches and in the reconstruction."""
        image_shape = (144, 144, 144)
        n_channels = 4
        data = np.arange(np.prod(image_shape) * n_channels).reshape((n_channels,) + image_shape)
        patch_overlap = 16
        patch_shape = (64, 64, 64)
        # Patch indices are computed over the spatial dims only.
        patch_indices = compute_patch_indices(image_shape, patch_shape, patch_overlap)
        patches = [get_patch_from_3d_data(data, patch_shape, index) for index in patch_indices]
        self.assertEqual(patches[0].shape, (n_channels,) + patch_shape)

        reconstructed_data = reconstruct_from_patches(patches, patch_indices, data.shape)
        np.testing.assert_array_equal(reconstructed_data, data)