Switch to side-by-side view

--- a
+++ b/opengait/modeling/models/gln.py
@@ -0,0 +1,169 @@
+import torch
+import copy
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..base_model import BaseModel
+from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
+
+
+class GLN(BaseModel):
+    """
+        http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf
+        Gait Lateral Network: Learning Discriminative and Compact Representations for Gait Recognition
+    """
+
+    def build_network(self, model_cfg):
+        in_channels = model_cfg['in_channels']
+        self.bin_num = model_cfg['bin_num']
+        self.hidden_dim = model_cfg['hidden_dim']
+        lateral_dim = model_cfg['lateral_dim']
+        reduce_dim = self.hidden_dim
+        self.pretrain = model_cfg['Lateral_pretraining']
+
+        self.sil_stage_0 = nn.Sequential(BasicConv2d(in_channels[0], in_channels[1], 5, 1, 2),
+                                         nn.LeakyReLU(inplace=True),
+                                         BasicConv2d(
+                                             in_channels[1], in_channels[1], 3, 1, 1),
+                                         nn.LeakyReLU(inplace=True))
+
+        self.sil_stage_1 = nn.Sequential(BasicConv2d(in_channels[1], in_channels[2], 3, 1, 1),
+                                         nn.LeakyReLU(inplace=True),
+                                         BasicConv2d(
+                                             in_channels[2], in_channels[2], 3, 1, 1),
+                                         nn.LeakyReLU(inplace=True))
+
+        self.sil_stage_2 = nn.Sequential(BasicConv2d(in_channels[2], in_channels[3], 3, 1, 1),
+                                         nn.LeakyReLU(inplace=True),
+                                         BasicConv2d(
+                                             in_channels[3], in_channels[3], 3, 1, 1),
+                                         nn.LeakyReLU(inplace=True))
+
+        self.set_stage_1 = copy.deepcopy(self.sil_stage_1)
+        self.set_stage_2 = copy.deepcopy(self.sil_stage_2)
+
+        self.set_pooling = PackSequenceWrapper(torch.max)
+
+        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
+        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.sil_stage_0 = SetBlockWrapper(self.sil_stage_0)
+        self.sil_stage_1 = SetBlockWrapper(self.sil_stage_1)
+        self.sil_stage_2 = SetBlockWrapper(self.sil_stage_2)
+
+        self.lateral_layer1 = nn.Conv2d(
+            in_channels[1]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
+        self.lateral_layer2 = nn.Conv2d(
+            in_channels[2]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
+        self.lateral_layer3 = nn.Conv2d(
+            in_channels[3]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
+
+        self.smooth_layer1 = nn.Conv2d(
+            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
+        self.smooth_layer2 = nn.Conv2d(
+            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
+        self.smooth_layer3 = nn.Conv2d(
+            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
+
+        self.HPP = HorizontalPoolingPyramid()
+        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
+
+        if not self.pretrain:
+            self.encoder_bn = nn.BatchNorm1d(sum(self.bin_num)*3*self.hidden_dim)
+            self.encoder_bn.bias.requires_grad_(False)
+
+            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
+            self.reduce_ac = nn.ReLU(inplace=True)
+            self.reduce_fc = nn.Linear(sum(self.bin_num)*3*self.hidden_dim, reduce_dim, bias=False)
+
+            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
+            self.reduce_bn.bias.requires_grad_(False)
+
+            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)
+
+    def upsample_add(self, x, y):
+        return F.interpolate(x, scale_factor=2, mode='nearest') + y
+
+    def forward(self, inputs):
+        ipts, labs, _, _, seqL = inputs
+        sils = ipts[0]  # [n, s, h, w]
+        del ipts
+        if len(sils.size()) == 4:
+            sils = sils.unsqueeze(1)
+        n, _, s, h, w = sils.size()
+
+        ### stage 0 sil ###
+        sil_0_outs = self.sil_stage_0(sils)
+        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0]
+
+        ### stage 1 sil ###
+        sil_1_ipts = self.MaxP_sil(sil_0_outs)
+        sil_1_outs = self.sil_stage_1(sil_1_ipts)
+
+        ### stage 2 sil ###
+        sil_2_ipts = self.MaxP_sil(sil_1_outs)
+        sil_2_outs = self.sil_stage_2(sil_2_ipts)
+
+        ### stage 1 set ###
+        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0]
+        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0]
+        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set
+
+        ### stage 2 set ###
+        set_2_ipts = self.MaxP_set(set_1_outs)
+        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0]
+        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set
+
+        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
+        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
+        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)
+
+        # print(set1.shape,set2.shape,set3.shape,"***\n")
+
+        # lateral 
+        set3 = self.lateral_layer3(set3)
+        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
+        set1 = self.upsample_add(set2, self.lateral_layer1(set1))
+
+        set3 = self.smooth_layer3(set3)
+        set2 = self.smooth_layer2(set2)
+        set1 = self.smooth_layer1(set1)
+
+        set1 = self.HPP(set1)
+        set2 = self.HPP(set2)
+        set3 = self.HPP(set3)
+
+        feature = torch.cat([set1, set2, set3], -1)
+
+        feature = self.Head(feature)
+
+        # compact_bloack
+        if not self.pretrain:
+            bn_feature = self.encoder_bn(feature.view(n, -1))
+            bn_feature = bn_feature.view(*feature.shape).contiguous()
+
+            reduce_feature = self.reduce_dp(bn_feature)
+            reduce_feature = self.reduce_ac(reduce_feature)
+            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))
+
+            bn_reduce_feature = self.reduce_bn(reduce_feature)
+            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)  # n c
+
+            reduce_feature = reduce_feature.unsqueeze(1).contiguous()
+            bn_reduce_feature = bn_reduce_feature.unsqueeze(1).contiguous()
+
+        retval = {
+            'training_feat': {},
+            'visual_summary': {
+                'image/sils': sils.view(n*s, 1, h, w)
+            },
+            'inference_feat': {
+                'embeddings':  feature  # reduce_feature # bn_reduce_feature
+            }
+        }
+        if self.pretrain:
+            retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
+        else:
+            retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
+            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
+        return retval