import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
class GLN(BaseModel):
    """
    http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf
    Gait Lateral Network: Learning Discriminative and Compact Representations for Gait Recognition

    Three convolutional stages run on every silhouette frame (the ``sil``
    branch). From stage 1 on, a parallel ``set`` branch convolves the
    temporally max-pooled features. The per-stage sil/set features are
    fused top-down:
      1. 1x1 lateral convolutions,
      2. 3x3 smoothing convolutions,
      3. horizontal pooling into bins,
      4. projection by ``SeparateFCs``.
    Unless ``Lateral_pretraining`` is set, a compact block (BN, dropout,
    ReLU, FC, BN, classifier) also produces logits for a softmax loss.
    """

    def build_network(self, model_cfg):
        """Build every sub-module from ``model_cfg``.

        Always reads:
          - ``in_channels`` (indices 0-3; index 0 is the input channel count)
          - ``bin_num``, ``hidden_dim``, ``lateral_dim``
          - ``Lateral_pretraining``, ``SeparateFCs``
        Also reads ``dropout`` and ``class_num`` when not pretraining.
        """
        in_channels = model_cfg['in_channels']
        self.bin_num = model_cfg['bin_num']
        self.hidden_dim = model_cfg['hidden_dim']
        lateral_dim = model_cfg['lateral_dim']
        reduce_dim = self.hidden_dim
        self.pretrain = model_cfg['Lateral_pretraining']

        # Silhouette branch: each stage is two (conv, LeakyReLU) pairs.
        self.sil_stage_0 = nn.Sequential(BasicConv2d(in_channels[0], in_channels[1], 5, 1, 2),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[1], in_channels[1], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))
        self.sil_stage_1 = nn.Sequential(BasicConv2d(in_channels[1], in_channels[2], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[2], in_channels[2], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))
        self.sil_stage_2 = nn.Sequential(BasicConv2d(in_channels[2], in_channels[3], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[3], in_channels[3], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        # The set branch gets independent copies, taken BEFORE the sil stages
        # are wrapped below. They are therefore applied directly to pooled
        # features rather than per frame.
        self.set_stage_1 = copy.deepcopy(self.sil_stage_1)
        self.set_stage_2 = copy.deepcopy(self.sil_stage_2)
        self.set_pooling = PackSequenceWrapper(torch.max)

        # Downsampling: wrapped (per frame) for the sil branch, plain for the set branch.
        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)

        self.sil_stage_0 = SetBlockWrapper(self.sil_stage_0)
        self.sil_stage_1 = SetBlockWrapper(self.sil_stage_1)
        self.sil_stage_2 = SetBlockWrapper(self.sil_stage_2)

        # Lateral 1x1 convs take the concatenated (sil, set) features, hence *2.
        self.lateral_layer1 = nn.Conv2d(
            in_channels[1]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer2 = nn.Conv2d(
            in_channels[2]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer3 = nn.Conv2d(
            in_channels[3]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)

        self.smooth_layer1 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer2 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer3 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        # Use the configured bins so the pooled part count agrees with the
        # sum(self.bin_num)-based sizes of the compact block below.
        self.HPP = HorizontalPoolingPyramid(bin_num=self.bin_num)
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        if not self.pretrain:
            # NOTE(review): assumes Head emits hidden_dim channels over
            # 3 * sum(bin_num) parts (three pyramid levels) — confirm against
            # the SeparateFCs config.
            self.encoder_bn = nn.BatchNorm1d(sum(self.bin_num)*3*self.hidden_dim)
            self.encoder_bn.bias.requires_grad_(False)
            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
            self.reduce_ac = nn.ReLU(inplace=True)
            self.reduce_fc = nn.Linear(sum(self.bin_num)*3*self.hidden_dim, reduce_dim, bias=False)
            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
            self.reduce_bn.bias.requires_grad_(False)
            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)

    def upsample_add(self, x, y):
        """Nearest-upsample ``x`` to ``y``'s spatial size and add ``y``.

        When ``y`` is exactly twice ``x`` spatially, interpolating to
        ``y.shape[-2:]`` equals ``scale_factor=2``. It also works when an
        earlier 2x max-pool floored an odd height or width.
        """
        return F.interpolate(x, size=y.shape[-2:], mode='nearest') + y

    def forward(self, inputs):
        """Run GLN on one batch.

        ``inputs`` is unpacked as ``(ipts, labs, _, _, seqL)``, and
        ``ipts[0]`` holds the silhouettes. The returned dict has three keys:
          - ``training_feat``: a triplet entry, plus a softmax entry when not
            pretraining;
          - ``visual_summary``;
          - ``inference_feat``.
        """
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(1)  # -> [n, 1, s, h, w]
        n, _, s, h, w = sils.size()

        ### stage 0 sil ###
        sil_0_outs = self.sil_stage_0(sils)
        # Max over dim 2 (presumably the frame axis); [0] keeps torch.max's values.
        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0]

        ### stage 1 sil ###
        sil_1_ipts = self.MaxP_sil(sil_0_outs)
        sil_1_outs = self.sil_stage_1(sil_1_ipts)

        ### stage 2 sil ###
        sil_2_ipts = self.MaxP_sil(sil_1_outs)
        sil_2_outs = self.sil_stage_2(sil_2_ipts)

        ### stage 1 set ###
        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0]
        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0]
        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set

        ### stage 2 set ###
        set_2_ipts = self.MaxP_set(set_1_outs)
        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0]
        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set

        # Stage 0 has no set branch. Its pooled sil feature is paired with
        # itself to supply the in_channels[1]*2 input that lateral_layer1 expects.
        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)

        # Top-down lateral fusion (coarsest level first), then 3x3 smoothing.
        set3 = self.lateral_layer3(set3)
        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
        set1 = self.upsample_add(set2, self.lateral_layer1(set1))

        set3 = self.smooth_layer3(set3)
        set2 = self.smooth_layer2(set2)
        set1 = self.smooth_layer1(set1)

        # Horizontal bins from every level, concatenated on the last axis.
        feature = torch.cat([self.HPP(set1), self.HPP(set2), self.HPP(set3)], -1)
        feature = self.Head(feature)

        # Compact block: only built (and only run) when not pretraining.
        if not self.pretrain:
            bn_feature = self.encoder_bn(feature.view(n, -1))
            bn_feature = bn_feature.view(*feature.shape).contiguous()
            reduce_feature = self.reduce_dp(bn_feature)
            reduce_feature = self.reduce_ac(reduce_feature)
            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))
            bn_reduce_feature = self.reduce_bn(reduce_feature)
            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)  # [n, 1, class_num]

        retval = {
            'training_feat': {
                'triplet': {'embeddings': feature, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': feature
            }
        }
        if not self.pretrain:
            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
        return retval