Diff of /bme1312/utils.py [000000] .. [2147a4]

Switch to unified view

a b/bme1312/utils.py
1
import math
2
import io
3
4
import numpy as np
5
import torch
6
import torchvision.utils
7
from matplotlib import pyplot as plt
8
from matplotlib.colors import Normalize
9
from mpl_toolkits.axes_grid1 import make_axes_locatable
10
from skimage.color import rgb2gray
11
from skimage.metrics import structural_similarity
12
13
14
def plot_loss(loss):
15
    plt.figure()
16
    plt.plot(loss)
17
    plt.show()
18
    plt.close('all')
19
20
21
def imgshow(im, cmap=None, rgb_axis=None, dpi=100, figsize=(6.4, 4.8)):
22
    if isinstance(im, torch.Tensor):
23
        im = im.to('cpu').detach().cpu().numpy()
24
    if rgb_axis is not None:
25
        im = np.moveaxis(im, rgb_axis, -1)
26
        im = rgb2gray(im)
27
28
    plt.figure(dpi=dpi, figsize=figsize)
29
    norm_obj = Normalize(vmin=im.min(), vmax=im.max())
30
    plt.imshow(im, norm=norm_obj, cmap=cmap)
31
    plt.colorbar()
32
    plt.show()
33
    plt.close('all')
34
35
36
def imsshow(imgs, titles=None, num_col=5, dpi=100, cmap=None, is_colorbar=False, is_ticks=False):
37
    '''
38
    assume imgs's shape is (Nslice, Nx, Ny)
39
    '''
40
    num_imgs = len(imgs)
41
    num_row = math.ceil(num_imgs / num_col)
42
    fig_width = num_col * 3
43
    if is_colorbar:
44
        fig_width += num_col * 1.5
45
    fig_height = num_row * 3
46
    fig = plt.figure(dpi=dpi, figsize=(fig_width, fig_height))
47
    for i in range(num_imgs):
48
        ax = plt.subplot(num_row, num_col, i + 1)
49
        im = ax.imshow(imgs[i], cmap=cmap)
50
        if titles:
51
            plt.title(titles[i])
52
        if is_colorbar:
53
            cax = fig.add_axes([ax.get_position().x1 + 0.01, ax.get_position().y0, 0.01, ax.get_position().height])
54
            plt.colorbar(im, cax=cax)
55
        if not is_ticks:
56
            ax.set_xticks([])
57
            ax.set_yticks([])
58
    plt.show()
59
    plt.close('all')
60
61
62
def image_mask_overlay(image, mask) -> np.ndarray:
63
    """
64
    :param image: [H, W] float(0~1) or uint8(0~255)
65
    :param mask: [H, W] int64
66
    :return: [H, W, C]
67
    """
68
69
    def _fig2numpy(fig, dpi) -> np.ndarray:
70
        """
71
        Convert matplotlib figure to numpy array
72
        """
73
        io_buf = io.BytesIO()
74
        fig.savefig(io_buf, format='raw', dpi=dpi)
75
        io_buf.seek(0)
76
        img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
77
                            newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
78
        io_buf.close()
79
        return img_arr
80
81
82
    H, W = image.shape
83
    dpi = H
84
    # dpi = dpi * factor
85
    fig = plt.figure(figsize=(math.ceil(H / dpi), math.ceil(W / dpi)), dpi=dpi)
86
    plt.xticks([])
87
    plt.yticks([])
88
    ax = fig.subplots(1, 1)
89
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
90
91
    ax.imshow(image, cmap='gray', interpolation='nearest')
92
    ax.imshow(mask, cmap='jet', alpha=0.5)
93
    ax.axis('off')
94
    im = _fig2numpy(fig, dpi=dpi)
95
    plt.close(fig)
96
    return im
97
98
99
def make_grid_and_show(ims, nrow=5, cmap=None):
100
    if isinstance(ims, np.ndarray):
101
        ims = torch.from_numpy(ims)
102
103
    B, C, H, W = ims.shape
104
    grid_im = torchvision.utils.make_grid(ims, nrow=nrow)
105
    fig_h, fig_w = nrow * 2 + 1, (B / nrow) + 1
106
    imgshow(grid_im, cmap=cmap, rgb_axis=0, dpi=200, figsize=(fig_h, fig_w))
107
108
109
def int2preetyStr(num: int):
110
    s = str(num)
111
    remain_len = len(s)
112
    while remain_len - 3 > 0:
113
        s = s[:remain_len - 3] + ',' + s[remain_len - 3:]
114
        remain_len -= 3
115
    return s
116
117
118
def compute_num_params(module, is_trace=False):
119
    print(int2preetyStr(sum([p.numel() for p in module.parameters()])))
120
    if is_trace:
121
        for item in [f"[{int2preetyStr(info[1].numel())}] {info[0]}:{tuple(info[1].shape)}"
122
                     for info in module.named_parameters()]:
123
            print(item)
124
125
126
def tonp(x):
127
    if isinstance(x, torch.Tensor):
128
        return x.detach().cpu()
129
    else:
130
        return x
131
132
133
def pseudo2real(x):
134
    """
135
    :param x: [..., C=2, H, W]
136
    :return: [..., H, W]
137
    """
138
    return (x[..., 0, :, :] ** 2 + x[..., 1, :, :] ** 2) ** 0.5
139
140
141
def complex2pseudo(x):
142
    """
143
    :param x: [..., H, W] Complex
144
    :return: [...., C=2, H, W]
145
    """
146
    if isinstance(x, np.ndarray):
147
        return np.stack([x.real, x.imag], axis=-3)
148
    elif isinstance(x, torch.Tensor):
149
        return torch.stack([x.real, x.imag], dim=-3)
150
    else:
151
        raise RuntimeError("Unsupported type.")
152
153
154
def pseudo2complex(x):
155
    """
156
    :param x:  [..., C=2, H, W]
157
    :return: [..., H, W] Complex
158
    """
159
    return x[..., 0, :, :] + x[..., 1, :, :] * 1j
160
161
162
# ================================
163
# Preprocessing
164
# ================================
165
def minmax_normalize(x, eps=1e-8):
166
    min = x.min()
167
    max = x.max()
168
    return (x - min) / (max - min + eps)
169
170
171
# ================================
172
# kspace and image domain transform
173
# reference: [ismrmrd-python-tools/transform.py at master · ismrmrd/ismrmrd-python-tools · GitHub](https://github.com/ismrmrd/ismrmrd-python-tools/blob/master/ismrmrdtools/transform.py)
174
# ================================
175
def image2kspace(x):
176
    if isinstance(x, np.ndarray):
177
        x = np.fft.ifftshift(x, axes=(-2, -1))
178
        x = np.fft.fft2(x)
179
        x = np.fft.fftshift(x, axes=(-2, -1))
180
        return x
181
    elif isinstance(x, torch.Tensor):
182
        x = torch.fft.ifftshift(x, dim=(-2, -1))
183
        x = torch.fft.fft2(x)
184
        x = torch.fft.fftshift(x, dim=(-2, -1))
185
        return x
186
    else:
187
        raise RuntimeError("Unsupported type.")
188
189
190
def kspace2image(x):
191
    if isinstance(x, np.ndarray):
192
        x = np.fft.ifftshift(x, axes=(-2, -1))
193
        x = np.fft.ifft2(x)
194
        x = np.fft.fftshift(x, axes=(-2, -1))
195
        return x
196
    elif isinstance(x, torch.Tensor):
197
        x = torch.fft.ifftshift(x, dim=(-2, -1))
198
        x = torch.fft.ifft2(x)
199
        x = torch.fft.fftshift(x, dim=(-2, -1))
200
        return x
201
    else:
202
        raise RuntimeError("Unsupported type.")
203
204
205
# ======================================
206
# Metrics
207
# ======================================
208
def compute_mse(x, y):
209
    """
210
    REQUIREMENT: `x` and `y` can be any shape, but their shape have to be same
211
    """
212
    assert x.dtype == y.dtype and x.shape == y.shape, \
213
        'x and y is not compatible to compute MSE metric'
214
215
    if isinstance(x, np.ndarray):
216
        mse = np.mean(np.abs(x - y) ** 2)
217
218
    elif isinstance(x, torch.Tensor):
219
        mse = torch.mean(torch.abs(x - y) ** 2)
220
221
    else:
222
        raise RuntimeError(
223
            'Unsupported object type'
224
        )
225
    return mse
226
227
228
def compute_psnr(reconstructed_im, target_im, peak='normalized', is_minmax=False):
229
    '''
230
    Image must be of either Integer [0, 255] or Float value [0,1]
231
    :param peak: 'max' or 'normalize', max_intensity will be the maximum value of target_im if peek == 'max.
232
          when peek is 'normalized', max_intensity will be the maximum value depend on data representation (in this
233
          case, we assume your input should be normalized to [0,1])
234
    REQUIREMENT: `x` and `y` can be any shape, but their shape have to be same
235
    '''
236
    assert target_im.dtype == reconstructed_im.dtype and target_im.shape == reconstructed_im.shape, \
237
        'target_im and reconstructed_im is not compatible to compute PSNR metric'
238
    assert peak in {'max', 'normalized'}, \
239
        'peak mode is not supported'
240
241
    eps = 1e-8  # to avoid math error in log(x) when x=0
242
243
    if is_minmax:
244
        reconstructed_im = minmax_normalize(reconstructed_im, eps)
245
        target_im = minmax_normalize(target_im, eps)
246
247
    if isinstance(target_im, np.ndarray):
248
        max_intensity = 255 if target_im.dtype == np.uint8 else 1.0
249
        max_intensity = np.max(target_im).item() if peak == 'max' else max_intensity
250
        psnr = 20 * math.log10(max_intensity) - 10 * np.log10(compute_mse(reconstructed_im, target_im) + eps)
251
252
    elif isinstance(target_im, torch.Tensor):
253
        max_intensity = 255 if target_im.dtype == torch.uint8 else 1.0
254
        max_intensity = torch.max(target_im).item() if peak == 'max' else max_intensity
255
        psnr = 20 * math.log10(max_intensity) - 10 * torch.log10(compute_mse(reconstructed_im, target_im) + eps)
256
257
    else:
258
        raise RuntimeError(
259
            'Unsupported object type'
260
        )
261
    return psnr
262
263
264
def compute_ssim(reconstructed_im, target_im, is_minmax=False):
265
    """
266
    Compute structural similarity index between two batches using skimage library,
267
    which only accept 2D-image input. We have to specify where is image's axes.
268
269
    WARNING: this method using skimage's implementation, DOES NOT SUPPORT GRADIENT
270
    """
271
    assert target_im.dtype == reconstructed_im.dtype and target_im.shape == reconstructed_im.shape, \
272
        'target_im and reconstructed_im is not compatible to compute SSIM metric'
273
274
    if isinstance(target_im, np.ndarray):
275
        pass
276
    elif isinstance(target_im, torch.Tensor):
277
        target_im = target_im.detach().to('cpu').numpy()
278
        reconstructed_im = reconstructed_im.detach().to('cpu').numpy()
279
    else:
280
        raise RuntimeError(
281
            'Unsupported object type'
282
        )
283
    
284
    eps = 1e-8  # to avoid math error in log(x) when x=0
285
286
    if is_minmax:
287
        reconstructed_im = minmax_normalize(reconstructed_im, eps)
288
        target_im = minmax_normalize(target_im, eps)
289
    
290
    ssim_value = structural_similarity(target_im, reconstructed_im, \
291
        gaussian_weights=True, sigma=1.5, use_sample_covariance=False)
292
293
    return ssim_value