|
a |
|
b/utils.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
''' |
|
|
3 |
@time: 2019/9/12 15:16 |
|
|
4 |
|
|
|
5 |
@ author: javis |
|
|
6 |
''' |
|
|
7 |
import torch |
|
|
8 |
import numpy as np |
|
|
9 |
import time,os |
|
|
10 |
from sklearn.metrics import f1_score |
|
|
11 |
from torch import nn |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def mkdirs(path): |
|
|
15 |
if not os.path.exists(path): |
|
|
16 |
os.makedirs(path) |
|
|
17 |
|
|
|
18 |
#计算F1score |
|
|
19 |
def calc_f1(y_true, y_pre, threshold=0.5): |
|
|
20 |
y_true = y_true.view(-1).cpu().detach().numpy().astype(np.int) |
|
|
21 |
y_pre = y_pre.view(-1).cpu().detach().numpy() > threshold |
|
|
22 |
return f1_score(y_true, y_pre) |
|
|
23 |
|
|
|
24 |
#打印时间 |
|
|
25 |
def print_time_cost(since): |
|
|
26 |
time_elapsed = time.time() - since |
|
|
27 |
return '{:.0f}m{:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60) |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
# 调整学习率 |
|
|
31 |
def adjust_learning_rate(optimizer, lr): |
|
|
32 |
for param_group in optimizer.param_groups: |
|
|
33 |
param_group['lr'] = lr |
|
|
34 |
return lr |
|
|
35 |
|
|
|
36 |
#多标签使用类别权重 |
|
|
37 |
class WeightedMultilabel(nn.Module): |
|
|
38 |
def __init__(self, weights: torch.Tensor): |
|
|
39 |
super(WeightedMultilabel, self).__init__() |
|
|
40 |
self.cerition = nn.BCEWithLogitsLoss(reduction='none') |
|
|
41 |
self.weights = weights |
|
|
42 |
|
|
|
43 |
def forward(self, outputs, targets): |
|
|
44 |
loss = self.cerition(outputs, targets) |
|
|
45 |
return (loss * self.weights).mean() |