[40f229]: / model / network / gaitset.py

Download this file

121 lines (104 with data), 4.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import numpy as np
from .basic_blocks import SetBlock, BasicConv2d
class SetNet(nn.Module):
    """Set-based feature extractor over a stack of silhouette frames.

    Two branches run side by side:

    * ``set_layer1`` .. ``set_layer6`` process the per-frame tensor.
    * ``gl_layer1`` .. ``gl_layer4`` process the frame-wise maximum of the
      per-frame branch (taken after ``set_layer2`` and ``set_layer4``).

    After the last layer both branches are split into ``bin_num`` chunks,
    each chunk is pooled with ``mean + max`` and every pooled chunk is
    projected by its own ``(128, hidden_dim)`` matrix stored in ``fc_bin``.

    Args:
        hidden_dim: size of the last dimension of the returned feature.
    """

    def __init__(self, hidden_dim):
        super(SetNet, self).__init__()
        self.hidden_dim = hidden_dim
        # Segment offsets along the frame axis for the current forward pass.
        # ``None`` means "reduce over the whole frame axis of every sample".
        self.batch_frame = None

        # Per-frame branch.
        # NOTE(review): the second positional argument ``True`` of SetBlock
        # presumably enables spatial pooling -- confirm in basic_blocks.
        _set_in_channels = 1
        _set_channels = [32, 64, 128]
        self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, padding=2))
        self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, padding=1), True)
        self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 3, padding=1))
        self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 3, padding=1), True)
        self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, padding=1))
        self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, padding=1))

        # Set-level branch; its input is the frame-wise max of set_layer2's
        # output, hence 32 input channels.
        _gl_in_channels = 32
        _gl_channels = [64, 128]
        self.gl_layer1 = BasicConv2d(_gl_in_channels, _gl_channels[0], 3, padding=1)
        self.gl_layer2 = BasicConv2d(_gl_channels[0], _gl_channels[0], 3, padding=1)
        self.gl_layer3 = BasicConv2d(_gl_channels[0], _gl_channels[1], 3, padding=1)
        self.gl_layer4 = BasicConv2d(_gl_channels[1], _gl_channels[1], 3, padding=1)
        self.gl_pooling = nn.MaxPool2d(2)

        # Number of chunks per pooling scale. Both branches contribute one
        # pooled vector per chunk, hence ``sum(bin_num) * 2`` projections.
        self.bin_num = [1, 2, 4, 8, 16]
        self.fc_bin = nn.ParameterList([
            nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.zeros(sum(self.bin_num) * 2, 128, hidden_dim)))])

        # Weight initialisation. Uses the in-place ``*_`` initialisers; the
        # un-suffixed ``nn.init.constant`` / ``nn.init.normal`` are deprecated.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                # bias is None for Linear(bias=False)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None for BatchNorm(affine=False)
                if m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

    def _reduce_frames(self, x, reduce_fn):
        """Apply ``reduce_fn(tensor, 1)`` over dim 1, optionally per segment.

        With ``self.batch_frame`` unset the reduction covers all of dim 1 and
        the result of ``reduce_fn`` is returned unchanged.  Otherwise dim 1 of
        the 5-D tensor ``x`` is cut into consecutive slices
        ``[batch_frame[i], batch_frame[i + 1])``, each slice is reduced, and
        the per-slice results are concatenated along dim 0.

        Returns:
            A ``(values, indices)`` pair.
        """
        if self.batch_frame is None:
            return reduce_fn(x, 1)
        bounds = zip(self.batch_frame[:-1], self.batch_frame[1:])
        reduced = [reduce_fn(x[:, start:stop, :, :, :], 1) for start, stop in bounds]
        values = torch.cat([part[0] for part in reduced], 0)
        indices = torch.cat([part[1] for part in reduced], 0)
        return values, indices

    def frame_max(self, x):
        """Return ``(max, argmax)`` over dim 1 of ``x`` (per segment if set)."""
        return self._reduce_frames(x, torch.max)

    def frame_median(self, x):
        """Return ``(median, index)`` over dim 1 of ``x`` (per segment if set)."""
        return self._reduce_frames(x, torch.median)

    def forward(self, silho, batch_frame=None):
        """Compute binned set features for a stack of silhouettes.

        Args:
            silho: 4-D tensor; dim 0 is the batch and dim 1 the frame axis.
                A singleton channel axis is inserted at dim 2 internally.
            batch_frame: optional indexable whose first element is a 1-D
                integer tensor of segment lengths along the frame axis.
                Trailing zeros are ignored.  When given, frames beyond the
                total length are dropped and every segment is reduced
                separately, so the output batch size equals the number of
                segments times ``silho.size(0)``.
                NOTE(review): presumably used to pack several variable-length
                sequences into one sample -- confirm against the caller.

        Returns:
            ``(feature, None)`` where ``feature`` has shape
            ``(batch, sum(bin_num) * 2, hidden_dim)``.

        Raises:
            ValueError: if ``batch_frame`` is given but contains no frames.
        """
        # n: batch_size, s: frame_num, k: keypoints_num, c: channel
        if batch_frame is not None:
            frame_counts = batch_frame[0].data.cpu().numpy().tolist()
            # Drop trailing zero entries (padding at the end of the list).
            while frame_counts and frame_counts[-1] == 0:
                frame_counts.pop()
            if not frame_counts:
                raise ValueError("batch_frame does not contain any frames")
            frame_sum = int(np.sum(frame_counts))
            if frame_sum < silho.size(1):
                silho = silho[:, :frame_sum, :, :]
            self.batch_frame = [0] + np.cumsum(frame_counts).tolist()
        else:
            # Clear offsets left over from an earlier call; otherwise
            # frame_max would keep slicing with stale segment boundaries.
            self.batch_frame = None

        x = silho.unsqueeze(2)
        del silho  # drop the local reference to the un-expanded input

        x = self.set_layer1(x)
        x = self.set_layer2(x)
        gl = self.gl_layer1(self.frame_max(x)[0])
        gl = self.gl_layer2(gl)
        gl = self.gl_pooling(gl)

        x = self.set_layer3(x)
        x = self.set_layer4(x)
        gl = self.gl_layer3(gl + self.frame_max(x)[0])
        gl = self.gl_layer4(gl)

        x = self.set_layer5(x)
        x = self.set_layer6(x)
        x = self.frame_max(x)[0]
        gl = gl + x

        feature = []
        n, c, h, w = gl.size()
        for num_bin in self.bin_num:
            # Order matters: fc_bin rows are laid out as (x, gl) per scale.
            for fmap in (x, gl):
                z = fmap.view(n, c, num_bin, -1)
                feature.append(z.mean(3) + z.max(3)[0])
        # (n, c, bins) -> (bins, n, c) so each bin meets its own projection.
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()
        feature = feature.matmul(self.fc_bin[0])
        feature = feature.permute(1, 0, 2).contiguous()
        return feature, None