# [96354c] tests/dataset/augmentations/test_spatial_augmentations.py

import numpy as np
import pytest
from src.dataset.augmentations.spatial_augmentations import RandomRotation90
from src.dataset.utils.visualization import plot_3_view
from tests.dataset.patching.common import load_patient, get_brain_mask
@pytest.fixture(scope="function")
def volume():
    """Provide a small synthetic 3x3x5 float volume.

    The voxel at ``[i, j, k]`` holds the value ``j``. Every slice along the
    first axis is therefore identical: row 0 is all zeros, row 1 all ones and
    row 2 all twos.
    """
    row_values = np.arange(3, dtype=float).reshape(3, 1)
    # np.tile promotes the (3, 1) column to (1, 3, 1) and repeats it 3x / 1x / 5x.
    return np.tile(row_values, (3, 1, 5))
@pytest.fixture(scope="function")
def patient():
    """Provide ``(volume, segmentation, brain_mask)`` for one real patient.

    The volume and segmentation come from ``load_patient()`` and the mask from
    ``get_brain_mask()``; both helpers live in ``tests.dataset.patching.common``.
    """
    brain_volume, segmentation = load_patient()
    mask = get_brain_mask()
    return brain_volume, segmentation, mask
def test_random_rotation_90(volume):
    """RandomRotation90(p=1) must rotate the image and keep its masks aligned.

    The image gets a leading channel axis via ``np.expand_dims(..., axis=0)``.
    The same 3-D ``volume`` is passed as the segmentation and the brain mask, so
    after a consistent spatial transform ``rot_img[0]`` must equal ``rot_seg``.
    """
    rot = RandomRotation90(p=1)
    img = np.expand_dims(volume, axis=0)
    rot_img, rot_seg, _ = rot(img_and_mask=(img, volume, volume))

    # The previous check (`img.shape == img.shape`) compared the rotated result
    # with itself and could never fail; compare against the inputs instead.
    # A 90-degree rotation only permutes spatial axes. The channel axis and the
    # multiset of spatial sizes are unchanged, but the (3, 3, 5) fixture may come
    # back with reordered spatial sizes depending on the rotation plane.
    assert rot_img.shape[0] == img.shape[0]
    assert sorted(rot_img.shape[1:]) == sorted(volume.shape)
    # Rotation moves voxels without changing them, so the sorted values must match.
    # allclose rather than exact equality in case the implementation interpolates.
    np.testing.assert_allclose(np.sort(rot_img, axis=None), np.sort(img, axis=None))
    # Image channel 0 and the segmentation started identical, so they must
    # still coincide voxel-for-voxel after the (shared) rotation.
    assert rot_seg.shape == rot_img.shape[1:]
    np.testing.assert_allclose(rot_img[0], rot_seg)
def test_random_rotation_90_real_patient(patient):
    """Rotate a real patient scan, sanity-check shapes, and save 3-view plots.

    The indexing below (``volume[0, :, :, :]`` and ``seg[:, :, :]``) requires the
    volume to have a leading channel axis and the segmentation to be 3-D.
    """
    volume, seg, brain_mask = patient
    rot = RandomRotation90(p=1)
    rot_volume, rot_seg, _ = rot(img_and_mask=(volume, seg, brain_mask))

    # Previously this test only produced plots and asserted nothing, so it could
    # not fail on a broken rotation. A 90-degree rotation permutes spatial axes:
    # the channel axis and the multiset of spatial sizes must be preserved.
    assert rot_volume.shape[0] == volume.shape[0]
    assert sorted(rot_volume.shape[1:]) == sorted(volume.shape[1:])
    assert sorted(rot_seg.shape) == sorted(seg.shape)
    # When the segmentation is aligned with the image on input, it must remain
    # aligned after rotation (image and mask get the same transform).
    if seg.shape == volume.shape[1:]:
        assert rot_seg.shape == rot_volume.shape[1:]

    # NOTE(review): 100 is passed to plot_3_view as-is; presumably a slice index
    # that requires every spatial dimension to exceed 100 — confirm.
    # With save=True the plots are written to disk for visual inspection.
    plot_3_view("rotated_volume", rot_volume[0, :, :, :], 100, save=True)
    plot_3_view("rotated_seg", rot_seg[:, :, :], 100, save=True)
    plot_3_view("volume", volume[0, :, :, :], 100, save=True)
    plot_3_view("seg", seg[:, :, :], 100, save=True)