Diff of /model/network/gaitset.py [000000] .. [40f229]

--- a
+++ b/model/network/gaitset.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .basic_blocks import SetBlock, BasicConv2d
+
+
+class SetNet(nn.Module):
+    def __init__(self, hidden_dim):
+        super(SetNet, self).__init__()
+        self.hidden_dim = hidden_dim
+        self.batch_frame = None
+
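+        # Frame-level (set) feature extractor: three two-conv stages, each conv
+        # wrapped in SetBlock so it is applied to every frame of the set; the
+        # True flag on layers 2 and 4 turns on SetBlock's pooling option.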
+        _set_in_channels = 1
+        _set_channels = [32, 64, 128]
+        self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, padding=2))
+        self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, padding=1), True)
+        self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 3, padding=1))
+        self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 3, padding=1), True)
+        self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, padding=1))
+        self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, padding=1))
+
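+        # Set-level (global) branch: plain 2D convolutions applied to the
+        # frame-pooled (frame_max) feature maps and fused with the frame-level
+        # branch again in forward().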
+        _gl_in_channels = 32
+        _gl_channels = [64, 128]
+        self.gl_layer1 = BasicConv2d(_gl_in_channels, _gl_channels[0], 3, padding=1)
+        self.gl_layer2 = BasicConv2d(_gl_channels[0], _gl_channels[0], 3, padding=1)
+        self.gl_layer3 = BasicConv2d(_gl_channels[0], _gl_channels[1], 3, padding=1)
+        self.gl_layer4 = BasicConv2d(_gl_channels[1], _gl_channels[1], 3, padding=1)
+        self.gl_pooling = nn.MaxPool2d(2)
+
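+        # Horizontal Pyramid Mapping: both branches are pooled into strips of
+        # 1/2/4/8/16 bins, giving 2 * sum(bin_num) = 62 strips in total;
+        # fc_bin maps each 128-dim strip to hidden_dim with its own weights.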
+        self.bin_num = [1, 2, 4, 8, 16]
+        self.fc_bin = nn.ParameterList([
+            nn.Parameter(
+                nn.init.xavier_uniform_(
+                    torch.zeros(sum(self.bin_num) * 2, 128, hidden_dim)))])
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
+                nn.init.xavier_uniform_(m.weight.data)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight.data)
+                nn.init.constant_(m.bias.data, 0.0)
+            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+                nn.init.normal_(m.weight.data, 1.0, 0.02)
+                nn.init.constant_(m.bias.data, 0.0)
+
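+    # Max over the frame dimension. With self.batch_frame set (variable-length
+    # sets packed along dim 1), the max is taken per sample using the stored
+    # cumulative frame offsets; otherwise a single torch.max over dim 1 suffices.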
+    def frame_max(self, x):
+        if self.batch_frame is None:
+            return torch.max(x, 1)
+        else:
+            _tmp = [
+                torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
+                for i in range(len(self.batch_frame) - 1)
+                ]
+            max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
+            arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
+            return max_list, arg_max_list
+
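+    # Median over the frame dimension, mirroring frame_max; not used in forward().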
+    def frame_median(self, x):
+        if self.batch_frame is None:
+            return torch.median(x, 1)
+        else:
+            _tmp = [
+                torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
+                for i in range(len(self.batch_frame) - 1)
+                ]
+            median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
+            arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
+            return median_list, arg_median_list
+
+    def forward(self, silho, batch_frame=None):
+        # n: batch_size, s: frame_num, c: channel, h: height, w: width
+        if batch_frame is not None:
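+            # batch_frame[0] lists how many frames each sample contributes;
+            # trailing zeros are padding and are trimmed before the cumulative
+            # offsets are stored in self.batch_frame.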
+            batch_frame = batch_frame[0].data.cpu().numpy().tolist()
+            _ = len(batch_frame)
+            for i in range(len(batch_frame)):
+                if batch_frame[-(i + 1)] != 0:
+                    break
+                else:
+                    _ -= 1
+            batch_frame = batch_frame[:_]
+            frame_sum = np.sum(batch_frame)
+            if frame_sum < silho.size(1):
+                silho = silho[:, :frame_sum, :, :]
+            self.batch_frame = [0] + np.cumsum(batch_frame).tolist()
+        n = silho.size(0)
+        x = silho.unsqueeze(2)
+        del silho
+
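+        # Alternate frame-level convolutions with set pooling: after each stage,
+        # frame_max collapses the frame dimension and the result feeds (or is
+        # added into) the set-level branch gl.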
+        x = self.set_layer1(x)
+        x = self.set_layer2(x)
+        gl = self.gl_layer1(self.frame_max(x)[0])
+        gl = self.gl_layer2(gl)
+        gl = self.gl_pooling(gl)
+
+        x = self.set_layer3(x)
+        x = self.set_layer4(x)
+        gl = self.gl_layer3(gl + self.frame_max(x)[0])
+        gl = self.gl_layer4(gl)
+
+        x = self.set_layer5(x)
+        x = self.set_layer6(x)
+        x = self.frame_max(x)[0]
+        gl = gl + x
+
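+        # Horizontal Pyramid Mapping: split both x and gl into num_bin strips,
+        # pool each strip with mean + max, and collect all strips.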
+        feature = list()
+        n, c, h, w = gl.size()
+        for num_bin in self.bin_num:
+            z = x.view(n, c, num_bin, -1)
+            z = z.mean(3) + z.max(3)[0]
+            feature.append(z)
+            z = gl.view(n, c, num_bin, -1)
+            z = z.mean(3) + z.max(3)[0]
+            feature.append(z)
+        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()
+
+        feature = feature.matmul(self.fc_bin[0])
+        feature = feature.permute(1, 0, 2).contiguous()
+
+        return feature, None
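
A minimal usage sketch for the new module, under stated assumptions (not part of the diff): the repository root is on PYTHONPATH so that model.network.gaitset is importable, hidden_dim is 256, and the silhouettes are 64x44 as in the common CASIA-B preprocessing. These values are illustrative rather than fixed by this file.

    import torch
    from model.network.gaitset import SetNet

    net = SetNet(hidden_dim=256)         # hidden_dim is an assumption; match your config
    silho = torch.rand(4, 30, 64, 44)    # (batch, frames, height, width), assumed sizes
    feature, _ = net(silho)              # batch_frame=None: fixed-length sets
    print(feature.shape)                 # torch.Size([4, 62, 256]) under these assumptions

With these shapes the final feature map is 16x11, so every bin count in [1, 2, 4, 8, 16] divides h*w evenly, which the view() calls in forward() require.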