Diff of /CaraNet/utils/utils.py [000000] .. [6f3ba0]

Switch to unified view

a b/CaraNet/utils/utils.py
1
import torch
2
import numpy as np
3
from thop import profile
4
from thop import clever_format
5
6
7
def clip_gradient(optimizer, grad_clip):
    """Clamp all parameter gradients in-place to the range [-grad_clip, grad_clip].

    Useful for stabilising training when gradients occasionally explode.

    :param optimizer: optimizer whose parameter groups hold the gradients to clip
    :param grad_clip: absolute value bound applied element-wise to each gradient
    :return: None (gradients are modified in place)
    """
    for param_group in optimizer.param_groups:
        for p in param_group['params']:
            # Parameters outside the current graph have no gradient; skip them.
            if p.grad is None:
                continue
            p.grad.data.clamp_(-grad_clip, grad_clip)
18
19
20
def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    """Apply step decay to the learning rate of every parameter group.

    The learning rate is set to ``init_lr * decay_rate ** (epoch // decay_epoch)``,
    i.e. it is multiplied by ``decay_rate`` once every ``decay_epoch`` epochs.

    Bug fix: the original implementation ignored ``init_lr`` and did
    ``param_group['lr'] *= decay``, which compounds the decay on every call —
    calling it once per epoch shrank the LR multiplicatively each epoch instead
    of stepping it every ``decay_epoch`` epochs. Setting the LR from ``init_lr``
    makes the function idempotent and matches the intended schedule.

    :param optimizer: optimizer whose learning rate is updated in place
    :param init_lr: base learning rate at epoch 0
    :param epoch: current epoch index (0-based)
    :param decay_rate: multiplicative decay factor per step
    :param decay_epoch: number of epochs between decay steps
    :return: None
    """
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
24
25
26
class AvgMeter(object):
    """Running-average tracker for (loss) values.

    Keeps the latest value, the running mean over all updates, and the full
    history of values so that ``show()`` can report the mean of the most
    recent ``num`` entries.
    """

    def __init__(self, num=40):
        # num: window size used by show() for the trailing mean.
        self.num = num
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0      # most recent value passed to update()
        self.avg = 0      # running mean over all updates since reset()
        self.sum = 0      # weighted sum of values
        self.count = 0    # total weight (number of samples)
        self.losses = []  # history of raw values, for show()

    def update(self, val, n=1):
        """Record a new value.

        :param val: value to record (scalar tensor or number)
        :param n: weight / batch size represented by ``val``
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(val)

    def show(self):
        """Return the mean of the last ``num`` recorded values as a tensor.

        Fixes vs. original: uses the builtin ``max`` instead of
        ``np.maximum`` on two plain ints, and guards the empty case —
        ``torch.stack([])`` raises, so before any update a zero tensor is
        returned instead of crashing.
        """
        if not self.losses:
            return torch.tensor(0.0)
        window = self.losses[max(len(self.losses) - self.num, 0):]
        return torch.mean(torch.stack(window))
47
48
49
def CalParams(model, input_tensor):
    """Print the FLOPs and parameter count of ``model`` for one forward pass.

    Usage:
        Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter)
    Necessarity:
        from thop import profile
        from thop import clever_format

    :param model: the ``torch.nn.Module`` to profile
    :param input_tensor: example input fed to the model's forward pass
    :return: None (statistics are printed to stdout)
    """
    # Raw counts from THOP, then human-readable formatting (e.g. "1.234G").
    macs, n_params = profile(model, inputs=(input_tensor,))
    macs, n_params = clever_format([macs, n_params], "%.3f")
    print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(macs, n_params))