Diff of /bme1312/models.py [000000] .. [2147a4]

Switch to side-by-side view

--- a
+++ b/bme1312/models.py
@@ -0,0 +1,45 @@
+import torch
+from torch import nn
+
+from bme1301.utils import image2kspace, kspace2image, pseudo2real, pseudo2complex, complex2pseudo
+
+
+class DataConsistencyLayer(nn.Module):
+    """
+    This class support different types k-space data consistency
+    """
+
+    def __init__(self, is_data_fidelity=False):
+        super().__init__()
+        self.is_data_fidelity = is_data_fidelity
+        if is_data_fidelity:
+            self.data_fidelity = nn.Parameter(torch.randn(1))
+
+    def data_consistency(self, k, k0, mask):
+        """
+        :param k: input k-space (reconstructed kspace, 2D-Fourier transform of im)
+        :param k0: initially sampled k-space
+        :param mask: sampling pattern
+        """
+        if self.is_data_fidelity:
+            v = self.is_data_fidelity
+            k_dc = (1 - mask) * k + mask * (k + v * k0 / (1 + v))
+        else:
+            k_dc = (1 - mask) * k + mask * k0
+        return k_dc
+
+    def forward(self, im, k0, mask):
+        """
+        im   - Image in pseudo-complex [B, C=2, H, W]
+        k0   - original under-sampled Kspace in pseudo-complex [B, C=2, H, W]
+        mask - mask for Kspace in Real [B, H, W]
+        """
+        # mask need to add one axis to broadcast to pseudo-complex channel
+        k = image2kspace(pseudo2complex(im))  # [B, H, W] Complex
+        k0 = pseudo2complex(k0)
+        k_dc = self.data_consistency(k, k0, mask)  # [B, H, W] Complex
+        im_dc = complex2pseudo(kspace2image(k_dc))  # [B, C=2, H, W]
+
+        return im_dc
+
+