a b/U-Net/utils/utils.py
1
import matplotlib.pyplot as plt
2
3
4
def plot_img_and_mask(img, mask):
    """Display an image next to one boolean panel per mask label value.

    The figure has one row: the input image, followed by
    ``int(mask.max()) + 1`` mask panels. Panel ``i`` shows ``mask == i`` for
    every ``i`` in ``0 .. mask.max()``. Tick marks are hidden on every panel
    and the figure is shown with ``plt.show()``.

    Args:
        img: Image passed unchanged to ``Axes.imshow``.
        mask: Label map supporting ``.max()`` and elementwise ``==``
            (presumably an integer array/tensor of class indices - confirm
            against callers).
    """
    # Cast to a Python int before adding: a NumPy scalar keeps its dtype
    # (uint8 255 + 1 wraps to 0 under NumPy 2 promotion rules), and a torch
    # 0-d tensor is not a valid column count for plt.subplots / range().
    classes = int(mask.max()) + 1
    # squeeze=False always returns a 2-D array of Axes, so indexing works
    # even if only a single subplot is created.
    _, axes = plt.subplots(1, classes + 1, squeeze=False)
    ax = axes[0]

    ax[0].set_title('Input image')
    ax[0].imshow(img)
    for i in range(classes):
        # Panel i + 1 shows the pixels whose label equals i, so title it i.
        ax[i + 1].set_title(f'Mask (class {i})')
        ax[i + 1].imshow(mask == i)

    # plt.xticks/plt.yticks only affect the current (last) axes; clear the
    # ticks on every panel instead.
    for panel in ax:
        panel.set_xticks([])
        panel.set_yticks([])
    plt.show()