--- a
+++ b/opengait/modeling/base_model.py
@@ -0,0 +1,468 @@
+"""The base model definition.
+
+This module defines the abstract meta model class and the base model class. The base model
+implements the basic model functions, such as get_loader, build_network, and run_train.
+The API of the base model consists of run_train and run_test, which are used in `opengait/main.py`.
+
+Typical usage:
+
+BaseModel.run_train(model)
+BaseModel.run_test(model)
+"""
+import torch
+import numpy as np
+import os.path as osp
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data as tordata
+
+from tqdm import tqdm
+from torch.cuda.amp import autocast
+from torch.cuda.amp import GradScaler
+from abc import ABCMeta
+from abc import abstractmethod
+
+from . import backbones
+from .loss_aggregator import LossAggregator
+from data.transform import get_transform
+from data.collate_fn import CollateFn
+from data.dataset import DataSet
+import data.sampler as Samplers
+from utils import Odict, mkdir, ddp_all_gather
+from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
+from evaluation import evaluator as eval_functions
+from utils import NoOp
+from utils import get_msg_mgr
+
+__all__ = ['BaseModel']
+
+
+class MetaModel(metaclass=ABCMeta):
+    """The necessary functions for the base model.
+
+    This class defines the functions that the base model must provide; they are implemented in BaseModel.
+    """
+    @abstractmethod
+    def get_loader(self, data_cfg):
+        """Based on the given data_cfg, we get the data loader."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def build_network(self, model_cfg):
+        """Build your network here."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def init_parameters(self):
+        """Initialize the parameters of your network."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_optimizer(self, optimizer_cfg):
+        """Based on the given optimizer_cfg, we get the optimizer."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_scheduler(self, scheduler_cfg):
+        """Based on the given scheduler_cfg, we get the scheduler."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def save_ckpt(self, iteration):
+        """Save the checkpoint, including model parameter, optimizer and scheduler."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def resume_ckpt(self, restore_hint):
+        """Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def inputs_pretreament(self, inputs):
+        """Transform the input data based on transform setting."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def train_step(self, loss_num) -> bool:
+        """Do one training step."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def inference(self):
+        """Do inference (calculate features.)."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def run_train(model):
+        """Run a whole train schedule."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def run_test(model):
+        """Run a whole test schedule."""
+        raise NotImplementedError
+
+
+class BaseModel(MetaModel, nn.Module):
+    """Base model.
+
+    This class inherits the MetaModel class and implements the basic model functions, like get_loader, build_network, etc.
+
+    Attributes:
+        msg_mgr: the message manager.
+        cfgs: the configs.
+        iteration: the current iteration of the model.
+        engine_cfg: the configs of the engine (train or test).
+        save_path: the path to save the checkpoints.
+
+    """
+
+    def __init__(self, cfgs, training):
+        """Initialize the base model.
+
+        Complete the model initialization, including the data loader, the network, the optimizer, the scheduler, the loss.
+
+        Args:
+            cfgs: All of the configs.
+            training: Whether the model is in training mode.
+        """
+
+        super(BaseModel, self).__init__()
+        self.msg_mgr = get_msg_mgr()
+        self.cfgs = cfgs
+        self.iteration = 0
+        self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
+        if self.engine_cfg is None:
+            raise Exception("Initialize a model without -Engine-Cfgs-")
+
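+        # Mixed-precision training: the GradScaler rescales the loss to avoid float16 gradient underflow.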
+        if training and self.engine_cfg['enable_float16']:
+            self.Scaler = GradScaler()
+        self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
+                                  cfgs['model_cfg']['model'], self.engine_cfg['save_name'])
+
+        self.build_network(cfgs['model_cfg'])
+        self.init_parameters()
+        self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])
+
+        self.msg_mgr.log_info(cfgs['data_cfg'])
+        if training:
+            self.train_loader = self.get_loader(
+                cfgs['data_cfg'], train=True)
+        if not training or self.engine_cfg['with_test']:
+            self.test_loader = self.get_loader(
+                cfgs['data_cfg'], train=False)
+            self.evaluator_trfs = get_transform(
+                cfgs['evaluator_cfg']['transform'])
+
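+        # One process per GPU is assumed here, so the DDP rank doubles as the CUDA device index.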
+        self.device = torch.distributed.get_rank()
+        torch.cuda.set_device(self.device)
+        self.to(device=torch.device(
+            "cuda", self.device))
+
+        if training:
+            self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
+            self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
+            self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
+        self.train(training)
+        restore_hint = self.engine_cfg['restore_hint']
+        if restore_hint != 0:
+            self.resume_ckpt(restore_hint)
+
+    def get_backbone(self, backbone_cfg):
+        """Get the backbone of the model."""
+        if is_dict(backbone_cfg):
+            Backbone = get_attr_from([backbones], backbone_cfg['type'])
+            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
+            return Backbone(**valid_args)
+        if is_list(backbone_cfg):
+            Backbone = nn.ModuleList([self.get_backbone(cfg)
+                                      for cfg in backbone_cfg])
+            return Backbone
+        raise ValueError(
+            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")
+
+    def build_network(self, model_cfg):
+        if 'backbone_cfg' in model_cfg.keys():
+            self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
+
+    def init_parameters(self):
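+        # Xavier-initialize conv/linear weights; BatchNorm weights are drawn from N(1.0, 0.02) and biases set to zero.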
+        for m in self.modules():
+            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
+                nn.init.xavier_uniform_(m.weight.data)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias.data, 0.0)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight.data)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias.data, 0.0)
+            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
+                if m.affine:
+                    nn.init.normal_(m.weight.data, 1.0, 0.02)
+                    nn.init.constant_(m.bias.data, 0.0)
+
+    def get_loader(self, data_cfg, train=True):
+        sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler']
+        dataset = DataSet(data_cfg, train)
+
+        Sampler = get_attr_from([Samplers], sampler_cfg['type'])
+        valid_args = get_valid_args(Sampler, sampler_cfg, free_keys=[
+            'sample_type', 'type'])
+        sampler = Sampler(dataset, **valid_args)
+
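+        # The sampler yields whole index batches, so it is passed as batch_sampler rather than sampler.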
+        loader = tordata.DataLoader(
+            dataset=dataset,
+            batch_sampler=sampler,
+            collate_fn=CollateFn(dataset.label_set, sampler_cfg),
+            num_workers=data_cfg['num_workers'])
+        return loader
+
+    def get_optimizer(self, optimizer_cfg):
+        self.msg_mgr.log_info(optimizer_cfg)
+        optimizer = get_attr_from([optim], optimizer_cfg['solver'])
+        valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
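+        # Only trainable parameters (requires_grad=True) are handed to the solver.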
+        optimizer = optimizer(
+            filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)
+        return optimizer
+
+    def get_scheduler(self, scheduler_cfg):
+        self.msg_mgr.log_info(scheduler_cfg)
+        Scheduler = get_attr_from(
+            [optim.lr_scheduler], scheduler_cfg['scheduler'])
+        valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
+        scheduler = Scheduler(self.optimizer, **valid_arg)
+        return scheduler
+
+    def save_ckpt(self, iteration):
+        if torch.distributed.get_rank() == 0:
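+            # Only rank 0 writes checkpoints to avoid concurrent writes from multiple processes.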
+            mkdir(osp.join(self.save_path, "checkpoints/"))
+            save_name = self.engine_cfg['save_name']
+            checkpoint = {
+                'model': self.state_dict(),
+                'optimizer': self.optimizer.state_dict(),
+                'scheduler': self.scheduler.state_dict(),
+                'iteration': iteration}
+            torch.save(checkpoint,
+                       osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration)))
+
+    def _load_ckpt(self, save_name):
+        load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']
+
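+        # Map the checkpoint onto this process's GPU rather than the device it was saved from.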
+        checkpoint = torch.load(save_name, map_location=torch.device(
+            "cuda", self.device))
+        model_state_dict = checkpoint['model']
+
+        if not load_ckpt_strict:
+            self.msg_mgr.log_info("-------- Restored Params List --------")
+            self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
+                set(self.state_dict().keys()))))
+
+        self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
+        if self.training:
+            if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            else:
+                self.msg_mgr.log_warning(
+                    "Restore NO Optimizer from %s !!!" % save_name)
+            if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
+                self.scheduler.load_state_dict(
+                    checkpoint['scheduler'])
+            else:
+                self.msg_mgr.log_warning(
+                    "Restore NO Scheduler from %s !!!" % save_name)
+        self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
+
+    def resume_ckpt(self, restore_hint):
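+        # An int restore_hint selects an iteration of the current save_name; a string is treated as an explicit checkpoint path.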
+        if isinstance(restore_hint, int):
+            save_name = self.engine_cfg['save_name']
+            save_name = osp.join(
+                self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
+            self.iteration = restore_hint
+        elif isinstance(restore_hint, str):
+            save_name = restore_hint
+            self.iteration = 0
+        else:
+            raise ValueError(
+                "Error type for -Restore_Hint-, supported: int or string.")
+        self._load_ckpt(save_name)
+
+    def fix_BN(self):
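+        # Freeze every BatchNorm layer by switching it to eval mode so its running statistics stay fixed.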
+        for module in self.modules():
+            classname = module.__class__.__name__
+            if classname.find('BatchNorm') != -1:
+                module.eval()
+
+    def inputs_pretreament(self, inputs):
+        """Conduct transforms on input data.
+
+        Args:
+            inputs: the input data.
+        Returns:
+            tuple: training data including inputs, labels, and some meta data.
+        """
+        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
+        seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs
+        if len(seqs_batch) != len(seq_trfs):
+            raise ValueError(
+                "The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
+        requires_grad = bool(self.training)
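+        # Apply each transform to its corresponding input stream and convert the frames into tensors via np2var.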
+        seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
+                for trf, seq in zip(seq_trfs, seqs_batch)]
+
+        typs = typs_batch
+        vies = vies_batch
+
+        labs = list2var(labs_batch).long()
+
+        if seqL_batch is not None:
+            seqL_batch = np2var(seqL_batch).int()
+        seqL = seqL_batch
+
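+        # When sequence lengths are provided, keep only the first sum(seqL) frames of each input tensor.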
+        if seqL is not None:
+            seqL_sum = int(seqL.sum().data.cpu().numpy())
+            ipts = [_[:, :seqL_sum] for _ in seqs]
+        else:
+            ipts = seqs
+        del seqs
+        return ipts, labs, typs, vies, seqL
+
+    def train_step(self, loss_sum) -> bool:
+        """Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
+
+        Args:
+            loss_sum: The loss of the current batch.
+        Returns:
+            bool: True if the optimizer step was applied, False if it was skipped (e.g., by the AMP scaler).
+        """
+
+        self.optimizer.zero_grad()
+        if loss_sum <= 1e-9:
+            self.msg_mgr.log_warning(
+                "Find the loss sum less than 1e-9 but the training process will continue!")
+
+        if self.engine_cfg['enable_float16']:
+            self.Scaler.scale(loss_sum).backward()
+            self.Scaler.step(self.optimizer)
+            scale = self.Scaler.get_scale()
+            self.Scaler.update()
+            # The scaler skips the optimizer step when inf/NaN gradients are detected; see:
+            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
+            if scale != self.Scaler.get_scale():
+                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
+                    scale, self.Scaler.get_scale()))
+                return False
+        else:
+            loss_sum.backward()
+            self.optimizer.step()
+
+        self.iteration += 1
+        self.scheduler.step()
+        return True
+
+    def inference(self, rank):
+        """Inference all the test data.
+
+        Args:
+            rank: the rank of the current process.
+        Returns:
+            Odict: contains the inference results.
+        """
+        total_size = len(self.test_loader)
+        if rank == 0:
+            pbar = tqdm(total=total_size, desc='Transforming')
+        else:
+            pbar = NoOp()
+        batch_size = self.test_loader.batch_sampler.batch_size
+        rest_size = total_size
+        info_dict = Odict()
+        for inputs in self.test_loader:
+            ipts = self.inputs_pretreament(inputs)
+            with autocast(enabled=self.engine_cfg['enable_float16']):
+                retval = self.forward(ipts)
+                inference_feat = retval['inference_feat']
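+                # Gather the inference features from all DDP processes so every rank holds the full batch results.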
+                for k, v in inference_feat.items():
+                    inference_feat[k] = ddp_all_gather(v, requires_grad=False)
+                del retval
+            for k, v in inference_feat.items():
+                inference_feat[k] = ts2np(v)
+            info_dict.append(inference_feat)
+            rest_size -= batch_size
+            if rest_size >= 0:
+                update_size = batch_size
+            else:
+                update_size = total_size % batch_size
+            pbar.update(update_size)
+        pbar.close()
+        for k, v in info_dict.items():
+            v = np.concatenate(v)[:total_size]
+            info_dict[k] = v
+        return info_dict
+
+    @staticmethod
+    def run_train(model):
+        """Accept the instance object(model) here, and then run the train loop."""
+        for inputs in model.train_loader:
+            ipts = model.inputs_pretreament(inputs)
+            with autocast(enabled=model.engine_cfg['enable_float16']):
+                retval = model(ipts)
+                training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
+                del retval
+            loss_sum, loss_info = model.loss_aggregator(training_feat)
+            ok = model.train_step(loss_sum)
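+            # Skip logging and checkpointing for this iteration if the optimizer step was skipped (e.g., AMP scale change).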
+            if not ok:
+                continue
+
+            visual_summary.update(loss_info)
+            visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
+
+            model.msg_mgr.train_step(loss_info, visual_summary)
+            if model.iteration % model.engine_cfg['save_iter'] == 0:
+                # save the checkpoint
+                model.save_ckpt(model.iteration)
+
+                # run test if with_test = true
+                if model.engine_cfg['with_test']:
+                    model.msg_mgr.log_info("Running test...")
+                    model.eval()
+                    result_dict = BaseModel.run_test(model)
+                    model.train()
+                    if model.cfgs['trainer_cfg']['fix_BN']:
+                        model.fix_BN()
+                    if result_dict:
+                        model.msg_mgr.write_to_tensorboard(result_dict)
+                    model.msg_mgr.reset_time()
+            if model.iteration >= model.engine_cfg['total_iter']:
+                break
+
+    @staticmethod
+    def run_test(model):
+        """Accept the instance object(model) here, and then run the test loop."""
+        evaluator_cfg = model.cfgs['evaluator_cfg']
+        if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
+            raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
+                evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
+        rank = torch.distributed.get_rank()
+        with torch.no_grad():
+            info_dict = model.inference(rank)
+        if rank == 0:
+            loader = model.test_loader
+            label_list = loader.dataset.label_list
+            types_list = loader.dataset.types_list
+            views_list = loader.dataset.views_list
+
+            info_dict.update({
+                'labels': label_list, 'types': types_list, 'views': views_list})
+
+            if 'eval_func' in evaluator_cfg.keys():
+                eval_func = evaluator_cfg["eval_func"]
+            else:
+                eval_func = 'identification'
+            eval_func = getattr(eval_functions, eval_func)
+            valid_args = get_valid_args(
+                eval_func, evaluator_cfg, ['metric'])
+            try:
+                dataset_name = model.cfgs['data_cfg']['test_dataset_name']
+            except KeyError:
+                dataset_name = model.cfgs['data_cfg']['dataset_name']
+            return eval_func(info_dict, dataset_name, **valid_args)