|
a |
|
b/CaraNet/utils/utils.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
from thop import profile |
|
|
4 |
from thop import clever_format |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def clip_gradient(optimizer, grad_clip): |
|
|
8 |
""" |
|
|
9 |
For calibrating misalignment gradient via cliping gradient technique |
|
|
10 |
:param optimizer: |
|
|
11 |
:param grad_clip: |
|
|
12 |
:return: |
|
|
13 |
""" |
|
|
14 |
for group in optimizer.param_groups: |
|
|
15 |
for param in group['params']: |
|
|
16 |
if param.grad is not None: |
|
|
17 |
param.grad.data.clamp_(-grad_clip, grad_clip) |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): |
|
|
21 |
decay = decay_rate ** (epoch // decay_epoch) |
|
|
22 |
for param_group in optimizer.param_groups: |
|
|
23 |
param_group['lr'] *= decay |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
class AvgMeter(object): |
|
|
27 |
def __init__(self, num=40): |
|
|
28 |
self.num = num |
|
|
29 |
self.reset() |
|
|
30 |
|
|
|
31 |
def reset(self): |
|
|
32 |
self.val = 0 |
|
|
33 |
self.avg = 0 |
|
|
34 |
self.sum = 0 |
|
|
35 |
self.count = 0 |
|
|
36 |
self.losses = [] |
|
|
37 |
|
|
|
38 |
def update(self, val, n=1): |
|
|
39 |
self.val = val |
|
|
40 |
self.sum += val * n |
|
|
41 |
self.count += n |
|
|
42 |
self.avg = self.sum / self.count |
|
|
43 |
self.losses.append(val) |
|
|
44 |
|
|
|
45 |
def show(self): |
|
|
46 |
return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) |
|
|
47 |
|
|
|
48 |
|
|
|
49 |
def CalParams(model, input_tensor): |
|
|
50 |
""" |
|
|
51 |
Usage: |
|
|
52 |
Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) |
|
|
53 |
Necessarity: |
|
|
54 |
from thop import profile |
|
|
55 |
from thop import clever_format |
|
|
56 |
:param model: |
|
|
57 |
:param input_tensor: |
|
|
58 |
:return: |
|
|
59 |
""" |
|
|
60 |
flops, params = profile(model, inputs=(input_tensor,)) |
|
|
61 |
flops, params = clever_format([flops, params], "%.3f") |
|
|
62 |
print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) |