[cc8b8f]: / augm / lambda_channel.py

Download this file

38 lines (33 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torchio
from torchio import TYPE, DATA, Subject
from torchio.transforms import Lambda
class LambdaChannel(Lambda):
    """Lambda transform that hands each image's whole ``image[DATA]`` tensor to the callable.

    The data tensor is passed to ``self.function`` as-is, without slicing.
    Presumably this is so the callable sees every channel, as the class name
    suggests; TODO confirm the intended tensor layout with the callers.
    ``self.function`` and ``self.types_to_apply`` are read but not set here.
    They are presumably initialised by the parent ``Lambda`` constructor.
    """

    def apply_transform(self, sample: Subject) -> dict:
        """Apply ``self.function`` to matching images of *sample*, in place.

        Images come from ``sample.get_images(intensity_only=False)``. An image
        is skipped when ``self.types_to_apply`` is not ``None`` and the image's
        ``TYPE`` entry is not in it. For every other image, the callable's
        result replaces ``image[DATA]``.

        Returns:
            The same *sample* object that was passed in, mutated in place.

        Raises:
            ValueError: If the callable returns something other than a
                ``torch.Tensor``, a tensor whose dtype is not
                ``torch.float32``, or a tensor whose number of dimensions
                differs from the input's.
        """
        for img in sample.get_images(intensity_only=False):
            # TYPE is read for every image, even when no type filter is set.
            img_kind = img[TYPE]
            if self.types_to_apply is not None and img_kind not in self.types_to_apply:
                continue
            source_tensor = img[DATA]
            produced = self.function(source_tensor)
            self._check_channel_result(source_tensor, produced)
            img[DATA] = produced
        return sample

    @staticmethod
    def _check_channel_result(source_tensor, produced) -> None:
        """Raise ValueError unless *produced* is a float32 tensor of the same rank as *source_tensor*."""
        if not isinstance(produced, torch.Tensor):
            raise ValueError(
                'The returned value from the callable argument must be'
                f' of type {torch.Tensor}, not {type(produced)}'
            )
        if produced.dtype != torch.float32:
            raise ValueError(
                'The data type of the returned value must be'
                f' of type {torch.float32}, not {produced.dtype}'
            )
        # Only the rank is compared; the individual dimension sizes are not checked.
        if produced.ndim != source_tensor.ndim:
            raise ValueError(
                'The number of dimensions of the returned value must'
                f' be {source_tensor.ndim}, not {produced.ndim}'
            )