--- a +++ b/bme1312/models.py @@ -0,0 +1,45 @@ +import torch +from torch import nn + +from bme1301.utils import image2kspace, kspace2image, pseudo2real, pseudo2complex, complex2pseudo + + +class DataConsistencyLayer(nn.Module): + """ + This class support different types k-space data consistency + """ + + def __init__(self, is_data_fidelity=False): + super().__init__() + self.is_data_fidelity = is_data_fidelity + if is_data_fidelity: + self.data_fidelity = nn.Parameter(torch.randn(1)) + + def data_consistency(self, k, k0, mask): + """ + :param k: input k-space (reconstructed kspace, 2D-Fourier transform of im) + :param k0: initially sampled k-space + :param mask: sampling pattern + """ + if self.is_data_fidelity: + v = self.is_data_fidelity + k_dc = (1 - mask) * k + mask * (k + v * k0 / (1 + v)) + else: + k_dc = (1 - mask) * k + mask * k0 + return k_dc + + def forward(self, im, k0, mask): + """ + im - Image in pseudo-complex [B, C=2, H, W] + k0 - original under-sampled Kspace in pseudo-complex [B, C=2, H, W] + mask - mask for Kspace in Real [B, H, W] + """ + # mask need to add one axis to broadcast to pseudo-complex channel + k = image2kspace(pseudo2complex(im)) # [B, H, W] Complex + k0 = pseudo2complex(k0) + k_dc = self.data_consistency(k, k0, mask) # [B, H, W] Complex + im_dc = complex2pseudo(kspace2image(k_dc)) # [B, C=2, H, W] + + return im_dc + +