a | b/U-Net/utils/utils.py | ||
---|---|---|---|
1 | import matplotlib.pyplot as plt |
||
2 | |||
3 | |||
4 | def plot_img_and_mask(img, mask): |
||
5 | classes = mask.max() + 1 |
||
6 | fig, ax = plt.subplots(1, classes + 1) |
||
7 | ax[0].set_title('Input image') |
||
8 | ax[0].imshow(img) |
||
9 | for i in range(classes): |
||
10 | ax[i + 1].set_title(f'Mask (class {i + 1})') |
||
11 | ax[i + 1].imshow(mask == i) |
||
12 | plt.xticks([]), plt.yticks([]) |
||
13 | plt.show() |