Switch to side-by-side view

--- a
+++ b/opengait/modeling/models/deepgaitv2.py
@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+
+import os
+import numpy as np
+import os.path as osp
+import matplotlib.pyplot as plt
+
+from ..base_model import BaseModel
+from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1x1, conv3x3, BasicBlock2D, BasicBlockP3D, BasicBlock3D
+
+from einops import rearrange
+
+blocks_map = {
+    '2d': BasicBlock2D, 
+    'p3d': BasicBlockP3D, 
+    '3d': BasicBlock3D
+}
+
+class DeepGaitV2(BaseModel):
+
+    def build_network(self, model_cfg):
+        mode = model_cfg['Backbone']['mode']
+        assert mode in blocks_map.keys()
+        block = blocks_map[mode]
+
+        in_channels = model_cfg['Backbone']['in_channels']
+        layers      = model_cfg['Backbone']['layers']
+        channels    = model_cfg['Backbone']['channels']
+        self.inference_use_emb2 = model_cfg['use_emb2'] if 'use_emb2' in model_cfg else False
+
+        if mode == '3d': 
+            strides = [
+                [1, 1], 
+                [1, 2, 2], 
+                [1, 2, 2], 
+                [1, 1, 1]
+            ]
+        else: 
+            strides = [
+                [1, 1], 
+                [2, 2], 
+                [2, 2], 
+                [1, 1]
+            ]
+
+        self.inplanes = channels[0]
+        self.layer0 = SetBlockWrapper(nn.Sequential(
+            conv3x3(in_channels, self.inplanes, 1), 
+            nn.BatchNorm2d(self.inplanes), 
+            nn.ReLU(inplace=True)
+        ))
+        self.layer1 = SetBlockWrapper(self.make_layer(BasicBlock2D, channels[0], strides[0], blocks_num=layers[0], mode=mode))
+
+        self.layer2 = self.make_layer(block, channels[1], strides[1], blocks_num=layers[1], mode=mode)
+        self.layer3 = self.make_layer(block, channels[2], strides[2], blocks_num=layers[2], mode=mode)
+        self.layer4 = self.make_layer(block, channels[3], strides[3], blocks_num=layers[3], mode=mode)
+
+        if mode == '2d': 
+            self.layer2 = SetBlockWrapper(self.layer2)
+            self.layer3 = SetBlockWrapper(self.layer3)
+            self.layer4 = SetBlockWrapper(self.layer4)
+
+        self.FCs = SeparateFCs(16, channels[3], channels[2])
+        self.BNNecks = SeparateBNNecks(16, channels[2], class_num=model_cfg['SeparateBNNecks']['class_num'])
+
+        self.TP = PackSequenceWrapper(torch.max)
+        self.HPP = HorizontalPoolingPyramid(bin_num=[16])
+
+    def make_layer(self, block, planes, stride, blocks_num, mode='2d'):
+
+        if max(stride) > 1 or self.inplanes != planes * block.expansion:
+            if mode == '3d':
+                downsample = nn.Sequential(nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=[1, 1, 1], stride=stride, padding=[0, 0, 0], bias=False), nn.BatchNorm3d(planes * block.expansion))
+            elif mode == '2d':
+                downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride=stride), nn.BatchNorm2d(planes * block.expansion))
+            elif mode == 'p3d':
+                downsample = nn.Sequential(nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=[1, 1, 1], stride=[1, *stride], padding=[0, 0, 0], bias=False), nn.BatchNorm3d(planes * block.expansion))
+            else:
+                raise TypeError('xxx')
+        else:
+            downsample = lambda x: x
+
+        layers = [block(self.inplanes, planes, stride=stride, downsample=downsample)]
+        self.inplanes = planes * block.expansion
+        s = [1, 1] if mode in ['2d', 'p3d'] else [1, 1, 1]
+        for i in range(1, blocks_num):
+            layers.append(
+                    block(self.inplanes, planes, stride=s)
+            )
+        return nn.Sequential(*layers)
+
+    def forward(self, inputs):
+        ipts, labs, typs, vies, seqL = inputs
+        
+        if len(ipts[0].size()) == 4:
+            sils = ipts[0].unsqueeze(1)
+        else:
+            sils = ipts[0]
+            sils = sils.transpose(1, 2).contiguous()
+        assert sils.size(-1) in [44, 88]
+
+        del ipts
+        out0 = self.layer0(sils)
+        out1 = self.layer1(out0)
+        out2 = self.layer2(out1)
+        out3 = self.layer3(out2)
+        out4 = self.layer4(out3) # [n, c, s, h, w]
+
+        # Temporal Pooling, TP
+        outs = self.TP(out4, seqL, options={"dim": 2})[0]  # [n, c, h, w]
+
+        # Horizontal Pooling Matching, HPM
+        feat = self.HPP(outs)  # [n, c, p]
+
+        embed_1 = self.FCs(feat)  # [n, c, p]
+        embed_2, logits = self.BNNecks(embed_1)  # [n, c, p]
+
+        if self.inference_use_emb2:
+                embed = embed_2
+        else:
+                embed = embed_1
+
+        retval = {
+            'training_feat': {
+                'triplet': {'embeddings': embed_1, 'labels': labs},
+                'softmax': {'logits': logits, 'labels': labs}
+            },
+            'visual_summary': {
+                'image/sils': rearrange(sils, 'n c s h w -> (n s) c h w'),
+            },
+            'inference_feat': {
+                'embeddings': embed
+            }
+        }
+
+        return retval