[40f229]: / work / OUMVLP_network / gaitset.py

Download this file

88 lines (75 with data), 3.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class SetNet(nn.Module):
    """Set-based network that maps silhouette sequences to a feature vector.

    Two branches run side by side:

    * a per-frame "set" branch (``set_layer1`` .. ``set_layer6``), and
    * a "global" branch (``gl_layer1`` .. ``gl_layer4``) that is fed the
      frame-wise max (see ``frame_max``) of the set branch after
      ``set_layer2``, after ``set_layer4`` and after ``set_layer6``.

    Each branch output goes through its own ``HPM`` head; the two head
    outputs are concatenated along dim 1.

    Args:
        hidden_dim: output size handed to both ``HPM`` heads.
    """

    def __init__(self, hidden_dim):
        super(SetNet, self).__init__()
        self.hidden_dim = hidden_dim
        # Cumulative frame offsets [0, n0, n0+n1, ...] used by frame_max when
        # several sequences are packed along dim 1; None = no packing.
        self.batch_frame = None

        _in_channels = 1
        _channels = [64, 128, 256]
        # Set branch. NOTE(review): the second SetBlock argument (True)
        # presumably enables pooling inside SetBlock — confirm in its source.
        self.set_layer1 = SetBlock(BasicConv2d(_in_channels, _channels[0], 5, padding=2))
        self.set_layer2 = SetBlock(BasicConv2d(_channels[0], _channels[0], 3, padding=1), True)
        self.set_layer3 = SetBlock(BasicConv2d(_channels[0], _channels[1], 3, padding=1))
        self.set_layer4 = SetBlock(BasicConv2d(_channels[1], _channels[1], 3, padding=1), True)
        self.set_layer5 = SetBlock(BasicConv2d(_channels[1], _channels[2], 3, padding=1))
        self.set_layer6 = SetBlock(BasicConv2d(_channels[2], _channels[2], 3, padding=1))

        # Global branch; input widths match the set-branch stage it is fed from.
        self.gl_layer1 = BasicConv2d(_channels[0], _channels[1], 3, padding=1)
        self.gl_layer2 = BasicConv2d(_channels[1], _channels[1], 3, padding=1)
        self.gl_layer3 = BasicConv2d(_channels[1], _channels[2], 3, padding=1)
        self.gl_layer4 = BasicConv2d(_channels[2], _channels[2], 3, padding=1)
        self.gl_pooling = nn.MaxPool2d(2)

        self.gl_hpm = HPM(_channels[-1], hidden_dim)
        self.x_hpm = HPM(_channels[-1], hidden_dim)

        # Weight init over every sub-module (including those nested in
        # SetBlock / BasicConv2d / HPM). Uses the in-place ``*_`` initialisers;
        # the non-underscore names are deprecated. Guards cover layers built
        # without bias / without affine parameters.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                if m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

    def frame_max(self, x):
        """Max-reduce ``x`` over dim 1 (the frame axis).

        When ``self.batch_frame`` is None, this is simply ``torch.max(x, 1)``.
        Otherwise dim 1 is cut into the segments delimited by consecutive
        offsets in ``self.batch_frame``; each segment is reduced on its own
        and the results are concatenated along dim 0.

        Returns:
            ``(values, indices)``. In the segmented case ``indices`` are
            relative to the start of each segment.
        """
        if self.batch_frame is None:
            return torch.max(x, 1)
        bounds = self.batch_frame
        pieces = [
            torch.max(x[:, start:end], 1)
            for start, end in zip(bounds[:-1], bounds[1:])
        ]
        max_list = torch.cat([values for values, _ in pieces], 0)
        arg_max_list = torch.cat([indices for _, indices in pieces], 0)
        return max_list, arg_max_list

    @staticmethod
    def _trim_trailing_zeros(lengths):
        """Return ``lengths`` with its trailing zero entries (padding) removed."""
        end = len(lengths)
        while end > 0 and lengths[end - 1] == 0:
            end -= 1
        return lengths[:end]

    def forward(self, silho, batch_frame=None):
        """Compute the embedding for a batch of silhouette sequences.

        Args:
            silho: 4-D tensor whose dim 1 is the frame axis (it is sliced as
                ``silho[:, :frame_sum]`` and compared against ``size(1)``).
                It is divided by 255 before use. NOTE(review): presumably
                (batch, frames, height, width) — confirm against the loader.
            batch_frame: optional tensor; ``batch_frame[0]`` is read as a list
                of per-sequence frame counts, with trailing zeros treated as
                padding. When given, frames beyond the total count are
                dropped and ``frame_max`` reduces each sequence separately.
                When omitted, each sample along dim 0 is one sequence.

        Returns:
            ``(features, None)`` where ``features`` is the concatenation along
            dim 1 of the global-branch and set-branch ``HPM`` outputs.
        """
        silho = silho / 255
        if batch_frame is None:
            # Clear offsets left by an earlier packed call; otherwise
            # frame_max would split dim 1 with stale boundaries.
            self.batch_frame = None
        else:
            lengths = self._trim_trailing_zeros(
                batch_frame[0].data.cpu().numpy().tolist()
            )
            # int() so an empty list gives an integer slice bound (np.sum([])
            # would be 0.0 and break the slice).
            frame_sum = int(sum(lengths))
            if frame_sum < silho.size(1):
                silho = silho[:, :frame_sum]
            self.batch_frame = [0] + np.cumsum(lengths).tolist()

        # Insert a channel axis at dim 2 for the convolutional set blocks.
        x = silho.unsqueeze(2)

        x = self.set_layer1(x)
        x = self.set_layer2(x)
        gl = self.gl_layer1(self.frame_max(x)[0])
        gl = self.gl_layer2(gl)
        gl = self.gl_pooling(gl)

        x = self.set_layer3(x)
        x = self.set_layer4(x)
        gl = self.gl_layer3(gl + self.frame_max(x)[0])
        gl = self.gl_layer4(gl)

        x = self.set_layer5(x)
        x = self.set_layer6(x)
        x = self.frame_max(x)[0]
        gl = gl + x

        gl_f = self.gl_hpm(gl)
        x_f = self.x_hpm(x)
        return torch.cat([gl_f, x_f], 1), None