--- a +++ b/model/network/gaitset.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import numpy as np + +from .basic_blocks import SetBlock, BasicConv2d + + +class SetNet(nn.Module): + def __init__(self, hidden_dim): + super(SetNet, self).__init__() + self.hidden_dim = hidden_dim + self.batch_frame = None + + _set_in_channels = 1 + _set_channels = [32, 64, 128] + self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, padding=2)) + self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, padding=1), True) + self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 3, padding=1)) + self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 3, padding=1), True) + self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, padding=1)) + self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, padding=1)) + + _gl_in_channels = 32 + _gl_channels = [64, 128] + self.gl_layer1 = BasicConv2d(_gl_in_channels, _gl_channels[0], 3, padding=1) + self.gl_layer2 = BasicConv2d(_gl_channels[0], _gl_channels[0], 3, padding=1) + self.gl_layer3 = BasicConv2d(_gl_channels[0], _gl_channels[1], 3, padding=1) + self.gl_layer4 = BasicConv2d(_gl_channels[1], _gl_channels[1], 3, padding=1) + self.gl_pooling = nn.MaxPool2d(2) + + self.bin_num = [1, 2, 4, 8, 16] + self.fc_bin = nn.ParameterList([ + nn.Parameter( + nn.init.xavier_uniform_( + torch.zeros(sum(self.bin_num) * 2, 128, hidden_dim)))]) + + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv1d)): + nn.init.xavier_uniform_(m.weight.data) + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight.data) + nn.init.constant(m.bias.data, 0.0) + elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + def frame_max(self, x): + if self.batch_frame is None: + return torch.max(x, 1) + else: + _tmp = [ + torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1) + for i in range(len(self.batch_frame) - 1) + ] + max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0) + arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0) + return max_list, arg_max_list + + def frame_median(self, x): + if self.batch_frame is None: + return torch.median(x, 1) + else: + _tmp = [ + torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1) + for i in range(len(self.batch_frame) - 1) + ] + median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0) + arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0) + return median_list, arg_median_list + + def forward(self, silho, batch_frame=None): + # n: batch_size, s: frame_num, k: keypoints_num, c: channel + if batch_frame is not None: + batch_frame = batch_frame[0].data.cpu().numpy().tolist() + _ = len(batch_frame) + for i in range(len(batch_frame)): + if batch_frame[-(i + 1)] != 0: + break + else: + _ -= 1 + batch_frame = batch_frame[:_] + frame_sum = np.sum(batch_frame) + if frame_sum < silho.size(1): + silho = silho[:, :frame_sum, :, :] + self.batch_frame = [0] + np.cumsum(batch_frame).tolist() + n = silho.size(0) + x = silho.unsqueeze(2) + del silho + + x = self.set_layer1(x) + x = self.set_layer2(x) + gl = self.gl_layer1(self.frame_max(x)[0]) + gl = self.gl_layer2(gl) + gl = self.gl_pooling(gl) + + x = self.set_layer3(x) + x = self.set_layer4(x) + gl = self.gl_layer3(gl + self.frame_max(x)[0]) + gl = self.gl_layer4(gl) + + x = self.set_layer5(x) + x = self.set_layer6(x) + x = self.frame_max(x)[0] + gl = gl + x + + feature = list() + n, c, h, w = gl.size() + for num_bin in self.bin_num: + z = x.view(n, c, num_bin, -1) + z = z.mean(3) + z.max(3)[0] + feature.append(z) + z = gl.view(n, c, num_bin, -1) + z = z.mean(3) + z.max(3)[0] + feature.append(z) + feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous() + + feature = feature.matmul(self.fc_bin[0]) + feature = feature.permute(1, 0, 2).contiguous() + + return feature, None