--- a
+++ b/augm/lambda_channel.py
@@ -0,0 +1,37 @@
+import torch
+import torchio
+from torchio import TYPE, DATA, Subject
+from torchio.transforms import Lambda
+
+
+class LambdaChannel(Lambda):
+    """Apply ``self.function`` to the ``DATA`` entry of each subject image."""
+
+    def apply_transform(self, sample: Subject) -> dict:
+        """Overwrite each selected image's data with the callable's output.
+
+        Raises ValueError unless the result is a float32 tensor of equal ndim.
+        """
+        for image in sample.get_images(intensity_only=False):
+            kind, allowed = image[TYPE], self.types_to_apply
+            if allowed is not None and kind not in allowed:
+                continue  # image type excluded by ``types_to_apply``
+            source = image[DATA]
+            output = self.function(source)
+            if not isinstance(output, torch.Tensor):
+                raise ValueError(
+                    'The returned value from the callable argument must be'
+                    f' of type {torch.Tensor}, not {type(output)}'
+                )
+            if output.dtype != torch.float32:
+                raise ValueError(
+                    'The data type of the returned value must be'
+                    f' of type {torch.float32}, not {output.dtype}'
+                )
+            if output.ndim != source.ndim:
+                raise ValueError(
+                    'The number of dimensions of the returned value must'
+                    f' be {source.ndim}, not {output.ndim}'
+                )
+            image[DATA] = output
+        return sample