Switch to unified view

a b/opengait/modeling/models/gaitset.py
1
import torch
2
import copy
3
import torch.nn as nn
4
5
from ..base_model import BaseModel
6
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
7
8
9
class GaitSet(BaseModel):
10
    """
11
        GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
12
        Arxiv:  https://arxiv.org/abs/1811.06186
13
        Github: https://github.com/AbnerHqC/GaitSet
14
    """
15
16
    def build_network(self, model_cfg):
17
        in_c = model_cfg['in_channels']
18
        self.set_block1 = nn.Sequential(BasicConv2d(in_c[0], in_c[1], 5, 1, 2),
19
                                        nn.LeakyReLU(inplace=True),
20
                                        BasicConv2d(in_c[1], in_c[1], 3, 1, 1),
21
                                        nn.LeakyReLU(inplace=True),
22
                                        nn.MaxPool2d(kernel_size=2, stride=2))
23
24
        self.set_block2 = nn.Sequential(BasicConv2d(in_c[1], in_c[2], 3, 1, 1),
25
                                        nn.LeakyReLU(inplace=True),
26
                                        BasicConv2d(in_c[2], in_c[2], 3, 1, 1),
27
                                        nn.LeakyReLU(inplace=True),
28
                                        nn.MaxPool2d(kernel_size=2, stride=2))
29
30
        self.set_block3 = nn.Sequential(BasicConv2d(in_c[2], in_c[3], 3, 1, 1),
31
                                        nn.LeakyReLU(inplace=True),
32
                                        BasicConv2d(in_c[3], in_c[3], 3, 1, 1),
33
                                        nn.LeakyReLU(inplace=True))
34
35
        self.gl_block2 = copy.deepcopy(self.set_block2)
36
        self.gl_block3 = copy.deepcopy(self.set_block3)
37
38
        self.set_block1 = SetBlockWrapper(self.set_block1)
39
        self.set_block2 = SetBlockWrapper(self.set_block2)
40
        self.set_block3 = SetBlockWrapper(self.set_block3)
41
42
        self.set_pooling = PackSequenceWrapper(torch.max)
43
44
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
45
46
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
47
48
    def forward(self, inputs):
49
        ipts, labs, _, _, seqL = inputs
50
        sils = ipts[0]  # [n, s, h, w]
51
        if len(sils.size()) == 4:
52
            sils = sils.unsqueeze(1)
53
54
        del ipts
55
        outs = self.set_block1(sils)
56
        gl = self.set_pooling(outs, seqL, options={"dim": 2})[0]
57
        gl = self.gl_block2(gl)
58
59
        outs = self.set_block2(outs)
60
        gl = gl + self.set_pooling(outs, seqL, options={"dim": 2})[0]
61
        gl = self.gl_block3(gl)
62
63
        outs = self.set_block3(outs)
64
        outs = self.set_pooling(outs, seqL, options={"dim": 2})[0]
65
        gl = gl + outs
66
67
        # Horizontal Pooling Matching, HPM
68
        feature1 = self.HPP(outs)  # [n, c, p]
69
        feature2 = self.HPP(gl)  # [n, c, p]
70
        feature = torch.cat([feature1, feature2], -1)  # [n, c, p]
71
        embs = self.Head(feature)
72
73
        n, _, s, h, w = sils.size()
74
        retval = {
75
            'training_feat': {
76
                'triplet': {'embeddings': embs, 'labels': labs}
77
            },
78
            'visual_summary': {
79
                'image/sils': sils.view(n*s, 1, h, w)
80
            },
81
            'inference_feat': {
82
                'embeddings': embs
83
            }
84
        }
85
        return retval