# FastRCNN/utils/visualization.py
import numpy as np


def vis_image(img, ax=None):
    """Display a single image channel as a grayscale picture.

    For a channel-first array only the first channel (``img[0]``) is drawn,
    using the ``'gray'`` colormap. A 3-channel RGB input is therefore
    rendered as its first (R) channel only, not in color.

    Args:
        img (~numpy.ndarray): Either a channel-first array of shape
            :math:`(C, height, width)` or a single-channel array of shape
            :math:`(height, width)`. Values are expected in
            :math:`[0, 255]`. Values outside that range are clipped before
            the conversion to ``uint8``, so they cannot wrap around.
        ax (matplotlib.axes.Axes): The visualization is displayed on this
            axis. If this is :obj:`None` (default), a new figure with a
            single axis is created.

    Returns:
        ~matplotlib.axes.Axes:
        The Axes object holding the plot, for further tweaking.

    Raises:
        ValueError: If ``img`` is neither 2-D nor 3-D.

    """
    if ax is None:
        # Import lazily: callers that pass their own axis never need pyplot
        # (or its backend selection).
        from matplotlib import pyplot as plot
        fig = plot.figure()
        ax = fig.add_subplot(1, 1, 1)

    img = np.asarray(img)
    if img.ndim == 3:
        # Channel-first (C, H, W): keep only the first channel. This is the
        # same selection as ``img.transpose((1, 2, 0))[:, :, 0]``.
        img = img[0]
    elif img.ndim != 2:
        raise ValueError(
            'img must have shape (C, height, width) or (height, width), '
            'got {}'.format(img.shape))

    # Clip before casting: a bare astype(np.uint8) wraps out-of-range values
    # (e.g. 256 -> 0), which would turn bright pixels black.
    ax.imshow(np.clip(img, 0, 255).astype(np.uint8), cmap='gray')
    return ax


def vis_bbox(img, bbox, label=None, score=None, label_names=None, ax=None,
             max_boxes=1):
    """Visualize bounding boxes inside an image.

    The image is drawn with :func:`vis_image`. The bounding boxes are then
    drawn as red rectangles, each optionally captioned with its label name
    and/or score. By default only the first box (``bbox[0]``) is drawn. Pass
    a larger ``max_boxes`` to draw more, or :obj:`None` to draw all of them.

    Example:

        >>> import numpy as np
        >>> img = np.zeros((1, 64, 64))
        >>> bbox = np.array([[8, 8, 40, 48]])
        >>> ax = vis_bbox(img, bbox, label=np.array([0]),
        ...               score=np.array([0.9]), label_names=['object'])

    Args:
        img (~numpy.ndarray): Image array forwarded to :func:`vis_image`.
            For a :math:`(C, height, width)` array only its first channel
            is shown, in grayscale. Values are expected in :math:`[0, 255]`.
        bbox (~numpy.ndarray): An array of shape :math:`(R, 4)`, where
            :math:`R` is the number of bounding boxes in the image.
            Each element is organized
            by :math:`(y_{min}, x_{min}, y_{max}, x_{max})` in the second axis.
        label (~numpy.ndarray): An integer array of shape :math:`(R,)`.
            The values correspond to id for label names stored in
            :obj:`label_names`. This is optional.
        score (~numpy.ndarray): A float array of shape :math:`(R,)`.
            Each value indicates how confident the prediction is.
            This is optional.
        label_names (iterable of strings): Name of labels ordered according
            to label ids. If this is :obj:`None`, labels will be skipped.
        ax (matplotlib.axes.Axes): The visualization is displayed on this
            axis. If this is :obj:`None` (default), a new axis is created.
        max_boxes (int or None): Draw at most the first ``max_boxes``
            boxes. The default ``1`` draws only ``bbox[0]``. :obj:`None`
            draws every box.

    Returns:
        ~matplotlib.axes.Axes:
        The Axes object holding the plot, for further tweaking.

    Raises:
        ValueError: If a drawn box's label id has no entry in
            ``label_names``.

    """
    from matplotlib import pyplot as plot

    # Lengths of label/score are intentionally not checked against bbox;
    # only the entries belonging to the drawn boxes are read.

    # Returns a newly created matplotlib.axes.Axes object if ax is None.
    ax = vis_image(img, ax=ax)

    # No bounding box to display: the image alone is the result.
    if len(bbox) == 0:
        return ax

    boxes = bbox if max_boxes is None else bbox[:max_boxes]
    for i, bb in enumerate(boxes):
        # bb is (y_min, x_min, y_max, x_max). matplotlib's Rectangle wants
        # the (x, y) anchor corner plus width and height.
        xy = (bb[1], bb[0])
        height = bb[2] - bb[0]
        width = bb[3] - bb[1]
        ax.add_patch(plot.Rectangle(
            xy, width, height, fill=False, edgecolor='red', linewidth=3))

        caption = []
        if label is not None and label_names is not None:
            lb = label[i]
            if not (0 <= lb < len(label_names)):
                raise ValueError('No corresponding name is given')
            caption.append(label_names[lb])
        if score is not None:
            caption.append('{:.2f}'.format(score[i]))

        if caption:
            ax.text(bb[1], bb[0], ': '.join(caption), style='italic')
    return ax