|
a |
|
b/bme1312/utils.py |
|
|
1 |
import math |
|
|
2 |
import io |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
import torch |
|
|
6 |
import torchvision.utils |
|
|
7 |
from matplotlib import pyplot as plt |
|
|
8 |
from matplotlib.colors import Normalize |
|
|
9 |
from mpl_toolkits.axes_grid1 import make_axes_locatable |
|
|
10 |
from skimage.color import rgb2gray |
|
|
11 |
from skimage.metrics import structural_similarity |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def plot_loss(loss): |
|
|
15 |
plt.figure() |
|
|
16 |
plt.plot(loss) |
|
|
17 |
plt.show() |
|
|
18 |
plt.close('all') |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def imgshow(im, cmap=None, rgb_axis=None, dpi=100, figsize=(6.4, 4.8)): |
|
|
22 |
if isinstance(im, torch.Tensor): |
|
|
23 |
im = im.to('cpu').detach().cpu().numpy() |
|
|
24 |
if rgb_axis is not None: |
|
|
25 |
im = np.moveaxis(im, rgb_axis, -1) |
|
|
26 |
im = rgb2gray(im) |
|
|
27 |
|
|
|
28 |
plt.figure(dpi=dpi, figsize=figsize) |
|
|
29 |
norm_obj = Normalize(vmin=im.min(), vmax=im.max()) |
|
|
30 |
plt.imshow(im, norm=norm_obj, cmap=cmap) |
|
|
31 |
plt.colorbar() |
|
|
32 |
plt.show() |
|
|
33 |
plt.close('all') |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def imsshow(imgs, titles=None, num_col=5, dpi=100, cmap=None, is_colorbar=False, is_ticks=False): |
|
|
37 |
''' |
|
|
38 |
assume imgs's shape is (Nslice, Nx, Ny) |
|
|
39 |
''' |
|
|
40 |
num_imgs = len(imgs) |
|
|
41 |
num_row = math.ceil(num_imgs / num_col) |
|
|
42 |
fig_width = num_col * 3 |
|
|
43 |
if is_colorbar: |
|
|
44 |
fig_width += num_col * 1.5 |
|
|
45 |
fig_height = num_row * 3 |
|
|
46 |
fig = plt.figure(dpi=dpi, figsize=(fig_width, fig_height)) |
|
|
47 |
for i in range(num_imgs): |
|
|
48 |
ax = plt.subplot(num_row, num_col, i + 1) |
|
|
49 |
im = ax.imshow(imgs[i], cmap=cmap) |
|
|
50 |
if titles: |
|
|
51 |
plt.title(titles[i]) |
|
|
52 |
if is_colorbar: |
|
|
53 |
cax = fig.add_axes([ax.get_position().x1 + 0.01, ax.get_position().y0, 0.01, ax.get_position().height]) |
|
|
54 |
plt.colorbar(im, cax=cax) |
|
|
55 |
if not is_ticks: |
|
|
56 |
ax.set_xticks([]) |
|
|
57 |
ax.set_yticks([]) |
|
|
58 |
plt.show() |
|
|
59 |
plt.close('all') |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def image_mask_overlay(image, mask) -> np.ndarray: |
|
|
63 |
""" |
|
|
64 |
:param image: [H, W] float(0~1) or uint8(0~255) |
|
|
65 |
:param mask: [H, W] int64 |
|
|
66 |
:return: [H, W, C] |
|
|
67 |
""" |
|
|
68 |
|
|
|
69 |
def _fig2numpy(fig, dpi) -> np.ndarray: |
|
|
70 |
""" |
|
|
71 |
Convert matplotlib figure to numpy array |
|
|
72 |
""" |
|
|
73 |
io_buf = io.BytesIO() |
|
|
74 |
fig.savefig(io_buf, format='raw', dpi=dpi) |
|
|
75 |
io_buf.seek(0) |
|
|
76 |
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8), |
|
|
77 |
newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)) |
|
|
78 |
io_buf.close() |
|
|
79 |
return img_arr |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
H, W = image.shape |
|
|
83 |
dpi = H |
|
|
84 |
# dpi = dpi * factor |
|
|
85 |
fig = plt.figure(figsize=(math.ceil(H / dpi), math.ceil(W / dpi)), dpi=dpi) |
|
|
86 |
plt.xticks([]) |
|
|
87 |
plt.yticks([]) |
|
|
88 |
ax = fig.subplots(1, 1) |
|
|
89 |
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0) |
|
|
90 |
|
|
|
91 |
ax.imshow(image, cmap='gray', interpolation='nearest') |
|
|
92 |
ax.imshow(mask, cmap='jet', alpha=0.5) |
|
|
93 |
ax.axis('off') |
|
|
94 |
im = _fig2numpy(fig, dpi=dpi) |
|
|
95 |
plt.close(fig) |
|
|
96 |
return im |
|
|
97 |
|
|
|
98 |
|
|
|
99 |
def make_grid_and_show(ims, nrow=5, cmap=None): |
|
|
100 |
if isinstance(ims, np.ndarray): |
|
|
101 |
ims = torch.from_numpy(ims) |
|
|
102 |
|
|
|
103 |
B, C, H, W = ims.shape |
|
|
104 |
grid_im = torchvision.utils.make_grid(ims, nrow=nrow) |
|
|
105 |
fig_h, fig_w = nrow * 2 + 1, (B / nrow) + 1 |
|
|
106 |
imgshow(grid_im, cmap=cmap, rgb_axis=0, dpi=200, figsize=(fig_h, fig_w)) |
|
|
107 |
|
|
|
108 |
|
|
|
109 |
def int2preetyStr(num: int): |
|
|
110 |
s = str(num) |
|
|
111 |
remain_len = len(s) |
|
|
112 |
while remain_len - 3 > 0: |
|
|
113 |
s = s[:remain_len - 3] + ',' + s[remain_len - 3:] |
|
|
114 |
remain_len -= 3 |
|
|
115 |
return s |
|
|
116 |
|
|
|
117 |
|
|
|
118 |
def compute_num_params(module, is_trace=False): |
|
|
119 |
print(int2preetyStr(sum([p.numel() for p in module.parameters()]))) |
|
|
120 |
if is_trace: |
|
|
121 |
for item in [f"[{int2preetyStr(info[1].numel())}] {info[0]}:{tuple(info[1].shape)}" |
|
|
122 |
for info in module.named_parameters()]: |
|
|
123 |
print(item) |
|
|
124 |
|
|
|
125 |
|
|
|
126 |
def tonp(x): |
|
|
127 |
if isinstance(x, torch.Tensor): |
|
|
128 |
return x.detach().cpu() |
|
|
129 |
else: |
|
|
130 |
return x |
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def pseudo2real(x): |
|
|
134 |
""" |
|
|
135 |
:param x: [..., C=2, H, W] |
|
|
136 |
:return: [..., H, W] |
|
|
137 |
""" |
|
|
138 |
return (x[..., 0, :, :] ** 2 + x[..., 1, :, :] ** 2) ** 0.5 |
|
|
139 |
|
|
|
140 |
|
|
|
141 |
def complex2pseudo(x): |
|
|
142 |
""" |
|
|
143 |
:param x: [..., H, W] Complex |
|
|
144 |
:return: [...., C=2, H, W] |
|
|
145 |
""" |
|
|
146 |
if isinstance(x, np.ndarray): |
|
|
147 |
return np.stack([x.real, x.imag], axis=-3) |
|
|
148 |
elif isinstance(x, torch.Tensor): |
|
|
149 |
return torch.stack([x.real, x.imag], dim=-3) |
|
|
150 |
else: |
|
|
151 |
raise RuntimeError("Unsupported type.") |
|
|
152 |
|
|
|
153 |
|
|
|
154 |
def pseudo2complex(x): |
|
|
155 |
""" |
|
|
156 |
:param x: [..., C=2, H, W] |
|
|
157 |
:return: [..., H, W] Complex |
|
|
158 |
""" |
|
|
159 |
return x[..., 0, :, :] + x[..., 1, :, :] * 1j |
|
|
160 |
|
|
|
161 |
|
|
|
162 |
# ================================ |
|
|
163 |
# Preprocessing |
|
|
164 |
# ================================ |
|
|
165 |
def minmax_normalize(x, eps=1e-8): |
|
|
166 |
min = x.min() |
|
|
167 |
max = x.max() |
|
|
168 |
return (x - min) / (max - min + eps) |
|
|
169 |
|
|
|
170 |
|
|
|
171 |
# ================================ |
|
|
172 |
# kspace and image domain transform |
|
|
173 |
# reference: [ismrmrd-python-tools/transform.py at master · ismrmrd/ismrmrd-python-tools · GitHub](https://github.com/ismrmrd/ismrmrd-python-tools/blob/master/ismrmrdtools/transform.py) |
|
|
174 |
# ================================ |
|
|
175 |
def image2kspace(x): |
|
|
176 |
if isinstance(x, np.ndarray): |
|
|
177 |
x = np.fft.ifftshift(x, axes=(-2, -1)) |
|
|
178 |
x = np.fft.fft2(x) |
|
|
179 |
x = np.fft.fftshift(x, axes=(-2, -1)) |
|
|
180 |
return x |
|
|
181 |
elif isinstance(x, torch.Tensor): |
|
|
182 |
x = torch.fft.ifftshift(x, dim=(-2, -1)) |
|
|
183 |
x = torch.fft.fft2(x) |
|
|
184 |
x = torch.fft.fftshift(x, dim=(-2, -1)) |
|
|
185 |
return x |
|
|
186 |
else: |
|
|
187 |
raise RuntimeError("Unsupported type.") |
|
|
188 |
|
|
|
189 |
|
|
|
190 |
def kspace2image(x): |
|
|
191 |
if isinstance(x, np.ndarray): |
|
|
192 |
x = np.fft.ifftshift(x, axes=(-2, -1)) |
|
|
193 |
x = np.fft.ifft2(x) |
|
|
194 |
x = np.fft.fftshift(x, axes=(-2, -1)) |
|
|
195 |
return x |
|
|
196 |
elif isinstance(x, torch.Tensor): |
|
|
197 |
x = torch.fft.ifftshift(x, dim=(-2, -1)) |
|
|
198 |
x = torch.fft.ifft2(x) |
|
|
199 |
x = torch.fft.fftshift(x, dim=(-2, -1)) |
|
|
200 |
return x |
|
|
201 |
else: |
|
|
202 |
raise RuntimeError("Unsupported type.") |
|
|
203 |
|
|
|
204 |
|
|
|
205 |
# ====================================== |
|
|
206 |
# Metrics |
|
|
207 |
# ====================================== |
|
|
208 |
def compute_mse(x, y): |
|
|
209 |
""" |
|
|
210 |
REQUIREMENT: `x` and `y` can be any shape, but their shape have to be same |
|
|
211 |
""" |
|
|
212 |
assert x.dtype == y.dtype and x.shape == y.shape, \ |
|
|
213 |
'x and y is not compatible to compute MSE metric' |
|
|
214 |
|
|
|
215 |
if isinstance(x, np.ndarray): |
|
|
216 |
mse = np.mean(np.abs(x - y) ** 2) |
|
|
217 |
|
|
|
218 |
elif isinstance(x, torch.Tensor): |
|
|
219 |
mse = torch.mean(torch.abs(x - y) ** 2) |
|
|
220 |
|
|
|
221 |
else: |
|
|
222 |
raise RuntimeError( |
|
|
223 |
'Unsupported object type' |
|
|
224 |
) |
|
|
225 |
return mse |
|
|
226 |
|
|
|
227 |
|
|
|
228 |
def compute_psnr(reconstructed_im, target_im, peak='normalized', is_minmax=False): |
|
|
229 |
''' |
|
|
230 |
Image must be of either Integer [0, 255] or Float value [0,1] |
|
|
231 |
:param peak: 'max' or 'normalize', max_intensity will be the maximum value of target_im if peek == 'max. |
|
|
232 |
when peek is 'normalized', max_intensity will be the maximum value depend on data representation (in this |
|
|
233 |
case, we assume your input should be normalized to [0,1]) |
|
|
234 |
REQUIREMENT: `x` and `y` can be any shape, but their shape have to be same |
|
|
235 |
''' |
|
|
236 |
assert target_im.dtype == reconstructed_im.dtype and target_im.shape == reconstructed_im.shape, \ |
|
|
237 |
'target_im and reconstructed_im is not compatible to compute PSNR metric' |
|
|
238 |
assert peak in {'max', 'normalized'}, \ |
|
|
239 |
'peak mode is not supported' |
|
|
240 |
|
|
|
241 |
eps = 1e-8 # to avoid math error in log(x) when x=0 |
|
|
242 |
|
|
|
243 |
if is_minmax: |
|
|
244 |
reconstructed_im = minmax_normalize(reconstructed_im, eps) |
|
|
245 |
target_im = minmax_normalize(target_im, eps) |
|
|
246 |
|
|
|
247 |
if isinstance(target_im, np.ndarray): |
|
|
248 |
max_intensity = 255 if target_im.dtype == np.uint8 else 1.0 |
|
|
249 |
max_intensity = np.max(target_im).item() if peak == 'max' else max_intensity |
|
|
250 |
psnr = 20 * math.log10(max_intensity) - 10 * np.log10(compute_mse(reconstructed_im, target_im) + eps) |
|
|
251 |
|
|
|
252 |
elif isinstance(target_im, torch.Tensor): |
|
|
253 |
max_intensity = 255 if target_im.dtype == torch.uint8 else 1.0 |
|
|
254 |
max_intensity = torch.max(target_im).item() if peak == 'max' else max_intensity |
|
|
255 |
psnr = 20 * math.log10(max_intensity) - 10 * torch.log10(compute_mse(reconstructed_im, target_im) + eps) |
|
|
256 |
|
|
|
257 |
else: |
|
|
258 |
raise RuntimeError( |
|
|
259 |
'Unsupported object type' |
|
|
260 |
) |
|
|
261 |
return psnr |
|
|
262 |
|
|
|
263 |
|
|
|
264 |
def compute_ssim(reconstructed_im, target_im, is_minmax=False): |
|
|
265 |
""" |
|
|
266 |
Compute structural similarity index between two batches using skimage library, |
|
|
267 |
which only accept 2D-image input. We have to specify where is image's axes. |
|
|
268 |
|
|
|
269 |
WARNING: this method using skimage's implementation, DOES NOT SUPPORT GRADIENT |
|
|
270 |
""" |
|
|
271 |
assert target_im.dtype == reconstructed_im.dtype and target_im.shape == reconstructed_im.shape, \ |
|
|
272 |
'target_im and reconstructed_im is not compatible to compute SSIM metric' |
|
|
273 |
|
|
|
274 |
if isinstance(target_im, np.ndarray): |
|
|
275 |
pass |
|
|
276 |
elif isinstance(target_im, torch.Tensor): |
|
|
277 |
target_im = target_im.detach().to('cpu').numpy() |
|
|
278 |
reconstructed_im = reconstructed_im.detach().to('cpu').numpy() |
|
|
279 |
else: |
|
|
280 |
raise RuntimeError( |
|
|
281 |
'Unsupported object type' |
|
|
282 |
) |
|
|
283 |
|
|
|
284 |
eps = 1e-8 # to avoid math error in log(x) when x=0 |
|
|
285 |
|
|
|
286 |
if is_minmax: |
|
|
287 |
reconstructed_im = minmax_normalize(reconstructed_im, eps) |
|
|
288 |
target_im = minmax_normalize(target_im, eps) |
|
|
289 |
|
|
|
290 |
ssim_value = structural_similarity(target_im, reconstructed_im, \ |
|
|
291 |
gaussian_weights=True, sigma=1.5, use_sample_covariance=False) |
|
|
292 |
|
|
|
293 |
return ssim_value |