"""The base model definition.
This module defines the abstract meta model class and the base model class. In the base model,
we define the basic model functions, such as get_loader, build_network, and run_train.
The API of the base model is run_train and run_test; they are used in `opengait/main.py`.
Typical usage:
BaseModel.run_train(model)
BaseModel.run_test(model)
"""
import torch
import numpy as np
import os.path as osp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from abc import ABCMeta
from abc import abstractmethod
from . import backbones
from .loss_aggregator import LossAggregator
from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions
from utils import NoOp
from utils import get_msg_mgr
# Public API of this module: a star-import exposes only BaseModel.
__all__ = ['BaseModel']
class MetaModel(metaclass=ABCMeta):
    """The necessary functions for the base model.

    This class only declares the interface: every method is abstract and
    raises ``NotImplementedError``. ``BaseModel`` in this module provides the
    concrete implementations, and the signatures here mirror the ones used
    there.
    """

    @abstractmethod
    def get_loader(self, data_cfg, train=True):
        """Based on the given data_cfg, we get the data loader.

        Args:
            data_cfg: the data configuration.
            train: whether the loader is built for training (default) or
                for testing.
        """
        raise NotImplementedError

    @abstractmethod
    def build_network(self, model_cfg):
        """Build your network here."""
        raise NotImplementedError

    @abstractmethod
    def init_parameters(self):
        """Initialize the parameters of your network."""
        raise NotImplementedError

    @abstractmethod
    def get_optimizer(self, optimizer_cfg):
        """Based on the given optimizer_cfg, we get the optimizer."""
        raise NotImplementedError

    @abstractmethod
    def get_scheduler(self, scheduler_cfg):
        """Based on the given scheduler_cfg, we get the scheduler."""
        raise NotImplementedError

    @abstractmethod
    def save_ckpt(self, iteration):
        """Save the checkpoint, including model parameter, optimizer and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def resume_ckpt(self, restore_hint):
        """Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def inputs_pretreament(self, inputs):
        """Transform the input data based on transform setting."""
        raise NotImplementedError

    @abstractmethod
    def train_step(self, loss_num) -> bool:
        """Do one training step.

        Args:
            loss_num: the loss of the current batch (``BaseModel`` names
                this argument ``loss_sum``).

        Returns:
            bool: whether the step was carried out.
        """
        raise NotImplementedError

    @abstractmethod
    def inference(self, rank=None):
        """Do inference (calculate features).

        Args:
            rank: the rank of the current process. It defaults to ``None``
                only to stay compatible with the former zero-argument
                declaration; ``BaseModel.inference`` requires it.
        """
        raise NotImplementedError

    # ``staticmethod`` must be the outermost decorator so that the abstract
    # flag set by ``abstractmethod`` is still visible to ABCMeta.
    @staticmethod
    @abstractmethod
    def run_train(model):
        """Run a whole train schedule on the given model instance."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def run_test(model):
        """Run a whole test schedule on the given model instance."""
        raise NotImplementedError
class BaseModel(MetaModel, nn.Module):
    """Base model.

    This class inherits from the MetaModel class, and implements the basic
    model functions, like get_loader, build_network, etc.

    Attributes:
        msg_mgr: the message manager.
        cfgs: the configs.
        iteration: the current iteration of the model.
        engine_cfg: the configs of the engine (train or test).
        save_path: the path to save the checkpoints.
    """

    def __init__(self, cfgs, training):
        """Initialize the base model.

        Complete the model initialization, including the data loader, the
        network, the optimizer, the scheduler, the loss.

        Args:
            cfgs:
                All of the configs.
            training:
                Whether the model is in training mode.

        Raises:
            Exception: if the selected engine config is ``None``.
        """
        super().__init__()
        self.msg_mgr = get_msg_mgr()
        self.cfgs = cfgs
        self.iteration = 0
        self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
        if self.engine_cfg is None:
            raise Exception("Initialize a model without -Engine-Cfgs-")

        # The gradient scaler is only needed for mixed-precision training.
        if training and self.engine_cfg['enable_float16']:
            self.Scaler = GradScaler()
        self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
                                  cfgs['model_cfg']['model'], self.engine_cfg['save_name'])

        self.build_network(cfgs['model_cfg'])
        self.init_parameters()
        # The trainer transform is built unconditionally, so 'trainer_cfg'
        # has to be present in test mode as well.
        self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])

        self.msg_mgr.log_info(cfgs['data_cfg'])
        if training:
            self.train_loader = self.get_loader(cfgs['data_cfg'], train=True)
        if not training or self.engine_cfg['with_test']:
            self.test_loader = self.get_loader(cfgs['data_cfg'], train=False)
            self.evaluator_trfs = get_transform(
                cfgs['evaluator_cfg']['transform'])

        # The distributed rank is used directly as the CUDA device index.
        self.device = torch.distributed.get_rank()
        torch.cuda.set_device(self.device)
        self.to(device=torch.device("cuda", self.device))

        if training:
            self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
            self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
            self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
        self.train(training)

        # A restore hint of 0 means "start from scratch".
        restore_hint = self.engine_cfg['restore_hint']
        if restore_hint != 0:
            self.resume_ckpt(restore_hint)

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model.

        Args:
            backbone_cfg: a dict with a 'type' key naming a class in the
                ``backbones`` package, or a list of such dicts.

        Returns:
            The instantiated backbone for a dict, or an ``nn.ModuleList`` of
            backbones for a list.

        Raises:
            ValueError: if ``backbone_cfg`` is neither a dict nor a list.
        """
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            return nn.ModuleList([self.get_backbone(cfg)
                                  for cfg in backbone_cfg])
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def build_network(self, model_cfg):
        """Build ``self.Backbone`` when ``model_cfg`` contains 'backbone_cfg'."""
        if 'backbone_cfg' in model_cfg:
            self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])

    def init_parameters(self):
        """Initialize the conv, linear and batch-norm layers of the network.

        Conv and linear weights use Xavier-uniform initialization with zero
        biases; affine batch-norm weights are drawn from N(1.0, 0.02) with
        zero biases.
        """
        weighted_layers = (nn.Conv3d, nn.Conv2d, nn.Conv1d, nn.Linear)
        norm_layers = (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)
        for m in self.modules():
            if isinstance(m, weighted_layers):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, norm_layers):
                if m.affine:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                    nn.init.constant_(m.bias.data, 0.0)

    def get_loader(self, data_cfg, train=True):
        """Based on the given data_cfg, we get the data loader.

        Args:
            data_cfg: the data configuration; 'num_workers' is read here.
            train: selects the trainer sampler config when True, the
                evaluator sampler config otherwise.

        Returns:
            A ``torch.utils.data.DataLoader`` driven by the configured batch
            sampler.
        """
        engine_key = 'trainer_cfg' if train else 'evaluator_cfg'
        sampler_cfg = self.cfgs[engine_key]['sampler']
        dataset = DataSet(data_cfg, train)

        Sampler = get_attr_from([Samplers], sampler_cfg['type'])
        valid_args = get_valid_args(Sampler, sampler_cfg, free_keys=[
            'sample_type', 'type'])
        sampler = Sampler(dataset, **valid_args)

        return tordata.DataLoader(
            dataset=dataset,
            batch_sampler=sampler,
            collate_fn=CollateFn(dataset.label_set, sampler_cfg),
            num_workers=data_cfg['num_workers'])

    def get_optimizer(self, optimizer_cfg):
        """Based on the given optimizer_cfg, we get the optimizer.

        The class is looked up in ``torch.optim`` by ``optimizer_cfg['solver']``
        and receives only the parameters that require gradients.
        """
        self.msg_mgr.log_info(optimizer_cfg)
        optimizer_cls = get_attr_from([optim], optimizer_cfg['solver'])
        valid_arg = get_valid_args(optimizer_cls, optimizer_cfg, ['solver'])
        return optimizer_cls(
            filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)

    def get_scheduler(self, scheduler_cfg):
        """Based on the given scheduler_cfg, we get the scheduler.

        The class is looked up in ``torch.optim.lr_scheduler`` by
        ``scheduler_cfg['scheduler']`` and wraps ``self.optimizer``.
        """
        self.msg_mgr.log_info(scheduler_cfg)
        Scheduler = get_attr_from(
            [optim.lr_scheduler], scheduler_cfg['scheduler'])
        valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
        return Scheduler(self.optimizer, **valid_arg)

    def save_ckpt(self, iteration):
        """Save the model, optimizer and scheduler state (rank 0 only).

        Args:
            iteration: the iteration number stored in the checkpoint and
                used in its file name.
        """
        # Only the first process writes, so ranks do not race on the file.
        if torch.distributed.get_rank() != 0:
            return
        mkdir(osp.join(self.save_path, "checkpoints/"))
        save_name = self.engine_cfg['save_name']
        checkpoint = {
            'model': self.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'iteration': iteration}
        torch.save(checkpoint,
                   osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration)))

    def _load_ckpt(self, save_name):
        """Load model (and, when training, optimizer/scheduler) state.

        Args:
            save_name: the path of the checkpoint file.
        """
        load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(save_name, map_location=torch.device(
            "cuda", self.device))
        model_state_dict = checkpoint['model']

        if not load_ckpt_strict:
            # With non-strict loading, report which parameters match.
            self.msg_mgr.log_info("-------- Restored Params List --------")
            self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
                set(self.state_dict().keys()))))

        self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
        if self.training:
            if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Optimizer from %s !!!" % save_name)
            if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Scheduler from %s !!!" % save_name)
        self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)

    def resume_ckpt(self, restore_hint):
        """Resume the model from the checkpoint.

        Args:
            restore_hint: an int iteration number, which selects the
                checkpoint under ``self.save_path`` and becomes the current
                iteration, or a str path to a checkpoint file, in which case
                the iteration counter restarts at 0.

        Raises:
            ValueError: if ``restore_hint`` is neither an int nor a str.
        """
        if isinstance(restore_hint, int):
            save_name = self.engine_cfg['save_name']
            save_name = osp.join(
                self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
            self.iteration = restore_hint
        elif isinstance(restore_hint, str):
            save_name = restore_hint
            self.iteration = 0
        else:
            raise ValueError(
                "Error type for -Restore_Hint-, supported: int or string.")
        self._load_ckpt(save_name)

    def fix_BN(self):
        """Put every module whose class name contains 'BatchNorm' in eval mode."""
        for module in self.modules():
            if 'BatchNorm' in module.__class__.__name__:
                module.eval()

    def inputs_pretreament(self, inputs):
        """Conduct transforms on input data.

        Args:
            inputs: a 5-tuple of (sequences, labels, types, views, sequence
                lengths) as yielded by the data loader.

        Returns:
            tuple: training data including inputs, labels, and some meta data.

        Raises:
            ValueError: if the number of input data types differs from the
                number of transforms.
        """
        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
        seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs
        if len(seqs_batch) != len(seq_trfs):
            raise ValueError(
                "The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))

        # Apply each transform frame by frame to its own input type.
        requires_grad = bool(self.training)
        seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
                for trf, seq in zip(seq_trfs, seqs_batch)]

        typs = typs_batch
        vies = vies_batch
        labs = list2var(labs_batch).long()

        if seqL_batch is None:
            return seqs, labs, typs, vies, None

        seqL = np2var(seqL_batch).int()
        # Keep only the first sum(seqL) entries along dim 1.
        # NOTE(review): presumably dim 1 is the frame axis - confirm.
        seqL_sum = int(seqL.sum().data.cpu().numpy())
        ipts = [seq[:, :seqL_sum] for seq in seqs]
        return ipts, labs, typs, vies, seqL

    def train_step(self, loss_sum) -> bool:
        """Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().

        Args:
            loss_sum: The loss of the current batch.

        Returns:
            bool: True if the optimizer step was applied; False if the
            mixed-precision step was skipped (the loss scale changed).
        """
        self.optimizer.zero_grad()
        if loss_sum <= 1e-9:
            self.msg_mgr.log_warning(
                "Find the loss sum less than 1e-9 but the training process will continue!")

        if self.engine_cfg['enable_float16']:
            self.Scaler.scale(loss_sum).backward()
            self.Scaler.step(self.optimizer)
            scale = self.Scaler.get_scale()
            self.Scaler.update()
            # Warning caused by optimizer skip when NaN
            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
            # A changed scale is treated as a skipped step, so neither the
            # iteration counter nor the scheduler is advanced.
            if scale != self.Scaler.get_scale():
                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
                    scale, self.Scaler.get_scale()))
                return False
        else:
            loss_sum.backward()
            self.optimizer.step()

        self.iteration += 1
        self.scheduler.step()
        return True

    def inference(self, rank):
        """Inference all the test data.

        Args:
            rank: the rank of the current process. Only rank 0 shows a
                progress bar.

        Returns:
            Odict: contains the inference results.
        """
        total_size = len(self.test_loader)
        if rank == 0:
            pbar = tqdm(total=total_size, desc='Transforming')
        else:
            pbar = NoOp()
        batch_size = self.test_loader.batch_sampler.batch_size
        rest_size = total_size
        info_dict = Odict()
        for inputs in self.test_loader:
            ipts = self.inputs_pretreament(inputs)
            with autocast(enabled=self.engine_cfg['enable_float16']):
                retval = self.forward(ipts)
                inference_feat = retval['inference_feat']
                # Collect the features from all processes.
                for k, v in inference_feat.items():
                    inference_feat[k] = ddp_all_gather(v, requires_grad=False)
                del retval
            for k, v in inference_feat.items():
                inference_feat[k] = ts2np(v)
            info_dict.append(inference_feat)

            # Advance the bar by a full batch, or by the remainder once the
            # remaining count drops below zero.
            rest_size -= batch_size
            if rest_size >= 0:
                update_size = batch_size
            else:
                update_size = total_size % batch_size
            pbar.update(update_size)
        pbar.close()

        # Merge the per-batch arrays and trim them to total_size.
        # NOTE(review): presumably this drops samples padded by the
        # sampler - confirm against the sampler implementation.
        for k, v in info_dict.items():
            info_dict[k] = np.concatenate(v)[:total_size]
        return info_dict

    @staticmethod
    def run_train(model):
        """Accept the instance object(model) here, and then run the train loop.

        Every ``save_iter`` iterations a checkpoint is saved and, when
        ``with_test`` is set, the test loop is run. The loop stops once
        ``total_iter`` iterations have been reached.
        """
        for inputs in model.train_loader:
            ipts = model.inputs_pretreament(inputs)
            with autocast(enabled=model.engine_cfg['enable_float16']):
                retval = model(ipts)
                training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
                del retval
            loss_sum, loss_info = model.loss_aggregator(training_feat)
            ok = model.train_step(loss_sum)
            if not ok:
                # The step was skipped; do not log or checkpoint it.
                continue

            visual_summary.update(loss_info)
            visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']

            model.msg_mgr.train_step(loss_info, visual_summary)
            if model.iteration % model.engine_cfg['save_iter'] == 0:
                # save the checkpoint
                model.save_ckpt(model.iteration)

                # run test if with_test = true
                if model.engine_cfg['with_test']:
                    model.msg_mgr.log_info("Running test...")
                    model.eval()
                    result_dict = BaseModel.run_test(model)
                    # Back to training mode; re-freeze BN if configured,
                    # because train() switches it back on.
                    model.train()
                    if model.cfgs['trainer_cfg']['fix_BN']:
                        model.fix_BN()
                    if result_dict:
                        model.msg_mgr.write_to_tensorboard(result_dict)
                    model.msg_mgr.reset_time()
            if model.iteration >= model.engine_cfg['total_iter']:
                break

    @staticmethod
    def run_test(model):
        """Accept the instance object(model) here, and then run the test loop.

        Returns:
            The result of the evaluation function on rank 0, ``None`` on
            every other rank.

        Raises:
            ValueError: if the evaluator batch size differs from the number
                of distributed processes.
        """
        evaluator_cfg = model.cfgs['evaluator_cfg']
        if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
            raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
                evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
        rank = torch.distributed.get_rank()
        with torch.no_grad():
            info_dict = model.inference(rank)

        # Evaluation is carried out by the first process only.
        if rank != 0:
            return None

        dataset = model.test_loader.dataset
        info_dict.update({
            'labels': dataset.label_list,
            'types': dataset.types_list,
            'views': dataset.views_list})

        if 'eval_func' in evaluator_cfg:
            eval_func_name = evaluator_cfg["eval_func"]
        else:
            eval_func_name = 'identification'
        eval_func = getattr(eval_functions, eval_func_name)
        valid_args = get_valid_args(eval_func, evaluator_cfg, ['metric'])

        # Prefer a dedicated test dataset name. Only a missing key triggers
        # the fallback, so unrelated errors are no longer swallowed.
        data_cfg = model.cfgs['data_cfg']
        try:
            dataset_name = data_cfg['test_dataset_name']
        except KeyError:
            dataset_name = data_cfg['dataset_name']
        return eval_func(info_dict, dataset_name, **valid_args)