--- a +++ b/CaraNet/utils/utils.py @@ -0,0 +1,62 @@ +import torch +import numpy as np +from thop import profile +from thop import clever_format + + +def clip_gradient(optimizer, grad_clip): + """ + For calibrating misalignment gradient via cliping gradient technique + :param optimizer: + :param grad_clip: + :return: + """ + for group in optimizer.param_groups: + for param in group['params']: + if param.grad is not None: + param.grad.data.clamp_(-grad_clip, grad_clip) + + +def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): + decay = decay_rate ** (epoch // decay_epoch) + for param_group in optimizer.param_groups: + param_group['lr'] *= decay + + +class AvgMeter(object): + def __init__(self, num=40): + self.num = num + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.losses = [] + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + self.losses.append(val) + + def show(self): + return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) + + +def CalParams(model, input_tensor): + """ + Usage: + Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) + Necessarity: + from thop import profile + from thop import clever_format + :param model: + :param input_tensor: + :return: + """ + flops, params = profile(model, inputs=(input_tensor,)) + flops, params = clever_format([flops, params], "%.3f") + print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) \ No newline at end of file