--- a +++ b/augm/lambda_channel.py @@ -0,0 +1,37 @@ +import torch +import torchio +from torchio import TYPE, DATA, Subject +from torchio.transforms import Lambda + + +class LambdaChannel(Lambda): + def apply_transform(self, sample: Subject) -> dict: + for image in sample.get_images(intensity_only=False): + + image_type = image[TYPE] + if self.types_to_apply is not None: + if image_type not in self.types_to_apply: + continue + + function_arg = image[DATA] + result = self.function(function_arg) + if not isinstance(result, torch.Tensor): + message = ( + 'The returned value from the callable argument must be' + f' of type {torch.Tensor}, not {type(result)}' + ) + raise ValueError(message) + if result.dtype != torch.float32: + message = ( + 'The data type of the returned value must be' + f' of type {torch.float32}, not {result.dtype}' + ) + raise ValueError(message) + if result.ndim != function_arg.ndim: + message = ( + 'The number of dimensions of the returned value must' + f' be {function_arg.ndim}, not {result.ndim}' + ) + raise ValueError(message) + image[DATA] = result + return sample