|
a |
|
b/opengait/modeling/models/gaitset.py |
|
|
1 |
import torch |
|
|
2 |
import copy |
|
|
3 |
import torch.nn as nn |
|
|
4 |
|
|
|
5 |
from ..base_model import BaseModel |
|
|
6 |
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class GaitSet(BaseModel):
    """
        GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
        Arxiv:  https://arxiv.org/abs/1811.06186
        Github: https://github.com/AbnerHqC/GaitSet

        The network runs two branches side by side:

        * a frame-level ("set") branch made of ``set_block1..3``, each wrapped
          in ``SetBlockWrapper``;
        * a set-level ("global") branch made of ``gl_block2`` and
          ``gl_block3``, fed with the set-pooled output of the frame-level
          branch.

        Both branches are pooled by ``HorizontalPoolingPyramid``, concatenated
        along the last axis and mapped to embeddings by ``SeparateFCs``.
    """

    def build_network(self, model_cfg):
        """Create the sub-modules of the network.

        Args:
            model_cfg: Mapping that provides
                * ``'in_channels'``: sequence of four channel counts
                  (input, block 1, block 2, block 3);
                * ``'SeparateFCs'``: keyword arguments forwarded to
                  ``SeparateFCs``;
                * ``'bin_num'``: forwarded to ``HorizontalPoolingPyramid``.
        """
        in_c = model_cfg['in_channels']

        # Frame-level feature extractors (plain 2-D conv stacks at this point).
        self.set_block1 = nn.Sequential(BasicConv2d(in_c[0], in_c[1], 5, 1, 2),
                                        nn.LeakyReLU(inplace=True),
                                        BasicConv2d(in_c[1], in_c[1], 3, 1, 1),
                                        nn.LeakyReLU(inplace=True),
                                        nn.MaxPool2d(kernel_size=2, stride=2))

        self.set_block2 = nn.Sequential(BasicConv2d(in_c[1], in_c[2], 3, 1, 1),
                                        nn.LeakyReLU(inplace=True),
                                        BasicConv2d(in_c[2], in_c[2], 3, 1, 1),
                                        nn.LeakyReLU(inplace=True),
                                        nn.MaxPool2d(kernel_size=2, stride=2))

        self.set_block3 = nn.Sequential(BasicConv2d(in_c[2], in_c[3], 3, 1, 1),
                                        nn.LeakyReLU(inplace=True),
                                        BasicConv2d(in_c[3], in_c[3], 3, 1, 1),
                                        nn.LeakyReLU(inplace=True))

        # The set-level branch reuses the architecture of blocks 2 and 3 but
        # must own its parameters, hence the deep copies. They are taken
        # BEFORE the frame-level blocks are wrapped, so the global blocks stay
        # plain nn.Sequential modules.
        self.gl_block2 = copy.deepcopy(self.set_block2)
        self.gl_block3 = copy.deepcopy(self.set_block3)

        # NOTE(review): SetBlockWrapper presumably applies the wrapped 2-D
        # block to every frame of a [n, c, s, h, w] tensor - confirm in
        # ..modules.
        self.set_block1 = SetBlockWrapper(self.set_block1)
        self.set_block2 = SetBlockWrapper(self.set_block2)
        self.set_block3 = SetBlockWrapper(self.set_block3)

        # Set pooling: torch.max routed through PackSequenceWrapper, which also
        # receives seqL at call time.
        self.set_pooling = PackSequenceWrapper(torch.max)

        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Compute embeddings for a batch of sequences.

        Args:
            inputs: 5-tuple ``(ipts, labs, _, _, seqL)``. ``ipts[0]`` holds the
                input tensor, either ``[n, s, h, w]`` (a channel axis is then
                inserted at position 1) or already 5-D ``[n, c, s, h, w]``.
                ``labs`` is passed through to the triplet-loss entry and
                ``seqL`` is forwarded to the set pooling.

        Returns:
            dict with three entries:
                * ``'training_feat'``: ``{'triplet': {'embeddings', 'labels'}}``
                * ``'visual_summary'``: ``{'image/sils'}``, the input frames
                  flattened to ``[n * s, c, h, w]``
                * ``'inference_feat'``: ``{'embeddings'}``
        """
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]  # [n, s, h, w]
        if sils.dim() == 4:
            sils = sils.unsqueeze(1)  # -> [n, 1, s, h, w]

        del ipts
        outs = self.set_block1(sils)
        # Pool over the frame axis (dim 2). [0] keeps the first element of the
        # pooled result (torch.max with a dim returns (values, indices)).
        gl = self.set_pooling(outs, seqL, options={"dim": 2})[0]
        gl = self.gl_block2(gl)

        outs = self.set_block2(outs)
        gl = gl + self.set_pooling(outs, seqL, options={"dim": 2})[0]
        gl = self.gl_block3(gl)

        outs = self.set_block3(outs)
        outs = self.set_pooling(outs, seqL, options={"dim": 2})[0]
        gl = gl + outs

        # Horizontal Pooling Matching, HPM
        feature1 = self.HPP(outs)  # [n, c, p]
        feature2 = self.HPP(gl)  # [n, c, p]
        # Parts of both branches are stacked along the last axis.
        feature = torch.cat([feature1, feature2], -1)  # [n, c, p]
        embs = self.Head(feature)

        n, c, s, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labs}
            },
            'visual_summary': {
                # Move frames next to the batch axis before flattening so each
                # output image keeps all of its channels. The previous
                # view(n*s, 1, h, w) raised for multi-channel or
                # non-contiguous inputs; for c == 1 the result is unchanged.
                'image/sils': sils.transpose(1, 2).reshape(n * s, c, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
        return retval