# opengait/modeling/models/gln.py
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F

from ..base_model import BaseModel
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper


class GLN(BaseModel):
    """
        http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf
        Gait Lateral Network: Learning Discriminative and Compact Representations for Gait Recognition
    """

    def build_network(self, model_cfg):
        in_channels = model_cfg['in_channels']
        self.bin_num = model_cfg['bin_num']
        self.hidden_dim = model_cfg['hidden_dim']
        lateral_dim = model_cfg['lateral_dim']
        reduce_dim = self.hidden_dim
        self.pretrain = model_cfg['Lateral_pretraining']
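        # The keys read above (plus 'SeparateFCs', 'dropout' and 'class_num' used further
        # down) are all this method expects from model_cfg. A purely illustrative sketch,
        # with assumed values that are NOT taken from the repo's config files:
        #   in_channels: [1, 32, 64, 128]    # silhouette input plus the three stage widths
        #   bin_num: [16, 8, 4, 2, 1]        # horizontal pyramid bins
        #   hidden_dim: 256                  # per-part embedding size (also reduce_dim)
        #   lateral_dim: 256                 # channels of the lateral/smooth convs
        #   Lateral_pretraining: true        # phase switch; false enables the compact block
        #   SeparateFCs: {...}               # kwargs forwarded to SeparateFCs
        #   dropout: 0.9, class_num: 74      # only read when Lateral_pretraining is false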

        self.sil_stage_0 = nn.Sequential(
            BasicConv2d(in_channels[0], in_channels[1], 5, 1, 2),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(in_channels[1], in_channels[1], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        self.sil_stage_1 = nn.Sequential(
            BasicConv2d(in_channels[1], in_channels[2], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(in_channels[2], in_channels[2], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        self.sil_stage_2 = nn.Sequential(
            BasicConv2d(in_channels[2], in_channels[3], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(in_channels[3], in_channels[3], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        self.set_stage_1 = copy.deepcopy(self.sil_stage_1)
        self.set_stage_2 = copy.deepcopy(self.sil_stage_2)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)

        self.sil_stage_0 = SetBlockWrapper(self.sil_stage_0)
        self.sil_stage_1 = SetBlockWrapper(self.sil_stage_1)
        self.sil_stage_2 = SetBlockWrapper(self.sil_stage_2)

        self.lateral_layer1 = nn.Conv2d(
            in_channels[1]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer2 = nn.Conv2d(
            in_channels[2]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer3 = nn.Conv2d(
            in_channels[3]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)

        self.smooth_layer1 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer2 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer3 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        self.HPP = HorizontalPoolingPyramid()
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        if not self.pretrain:
            self.encoder_bn = nn.BatchNorm1d(sum(self.bin_num)*3*self.hidden_dim)
            self.encoder_bn.bias.requires_grad_(False)

            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
            self.reduce_ac = nn.ReLU(inplace=True)
            self.reduce_fc = nn.Linear(sum(self.bin_num)*3*self.hidden_dim, reduce_dim, bias=False)

            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
            self.reduce_bn.bias.requires_grad_(False)

            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)

    def upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2, mode='nearest') + y
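
    # upsample_add is the FPN-style top-down merge: the coarser map `x` is upsampled 2x
    # with nearest-neighbour interpolation and added element-wise to the finer lateral
    # map `y`, so `y` is expected to have exactly twice the spatial size of `x`.
    # A minimal sketch with hypothetical shapes (not taken from the repo):
    #   x = torch.rand(2, 256, 16, 11)                          # coarse level
    #   y = torch.rand(2, 256, 32, 22)                          # finer level
    #   F.interpolate(x, scale_factor=2, mode='nearest') + y    # -> [2, 256, 32, 22]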

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(1)
        n, _, s, h, w = sils.size()

        ### stage 0 sil ###
        sil_0_outs = self.sil_stage_0(sils)
        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0]

        ### stage 1 sil ###
        sil_1_ipts = self.MaxP_sil(sil_0_outs)
        sil_1_outs = self.sil_stage_1(sil_1_ipts)

        ### stage 2 sil ###
        sil_2_ipts = self.MaxP_sil(sil_1_outs)
        sil_2_outs = self.sil_stage_2(sil_2_ipts)

        ### stage 1 set ###
        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0]
        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0]
        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set

        ### stage 2 set ###
        set_2_ipts = self.MaxP_set(set_1_outs)
        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0]
        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set
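
        # Two parallel branches feed the pyramid: the sil branch convolves every frame
        # (via SetBlockWrapper), while the set branch runs on the sequence-max-pooled maps
        # (via PackSequenceWrapper(torch.max)) and is fused back by element-wise addition
        # at stages 1 and 2.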

        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)
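
        # Each setN concatenates the frame-pooled feature with its set-branch counterpart
        # along the channel axis, which is why the lateral 1x1 convs take in_channels[i]*2
        # input channels; stage 0 has no set branch, so stage_0_sil_set is duplicated.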

        # print(set1.shape,set2.shape,set3.shape,"***\n")

        # lateral
        set3 = self.lateral_layer3(set3)
        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
        set1 = self.upsample_add(set2, self.lateral_layer1(set1))

        set3 = self.smooth_layer3(set3)
        set2 = self.smooth_layer2(set2)
        set1 = self.smooth_layer1(set1)

        set1 = self.HPP(set1)
        set2 = self.HPP(set2)
        set3 = self.HPP(set3)

        feature = torch.cat([set1, set2, set3], -1)

        feature = self.Head(feature)
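
        # Assuming the usual OpenGait shapes, each HPP call returns [n, c, sum(bin_num)],
        # the concatenation above yields [n, c, 3*sum(bin_num)] parts, and SeparateFCs maps
        # every part to hidden_dim channels, which is consistent with the
        # sum(self.bin_num)*3*self.hidden_dim sizes used by the compact block below.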

        # compact block
        if not self.pretrain:
            bn_feature = self.encoder_bn(feature.view(n, -1))
            bn_feature = bn_feature.view(*feature.shape).contiguous()

            reduce_feature = self.reduce_dp(bn_feature)
            reduce_feature = self.reduce_ac(reduce_feature)
            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))

            bn_reduce_feature = self.reduce_bn(reduce_feature)
            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)  # [n, 1, class_num]

            reduce_feature = reduce_feature.unsqueeze(1).contiguous()
            bn_reduce_feature = bn_reduce_feature.unsqueeze(1).contiguous()
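
        # During Lateral pretraining only the triplet loss on the FPN feature is used; once
        # pretraining is off, the compact block above additionally supplies `logits` for the
        # softmax loss, and the inference embedding below may be swapped for reduce_feature
        # or bn_reduce_feature.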

        retval = {
            'training_feat': {},
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': feature  # alternatively: reduce_feature / bn_reduce_feature
            }
        }
        if self.pretrain:
            retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
        else:
            retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
        return retval