Diff of /utils.py [000000] .. [dcdaea]

Switch to unified view

a b/utils.py
1
# -*- coding: utf-8 -*-
2
'''
3
@time: 2019/9/12 15:16
4
5
@ author: javis
6
'''
7
import torch
8
import numpy as np
9
import time,os
10
from sklearn.metrics import f1_score
11
from torch import nn
12
13
14
def mkdirs(path):
15
    if not os.path.exists(path):
16
        os.makedirs(path)
17
18
#计算F1score
19
def calc_f1(y_true, y_pre, threshold=0.5):
20
    y_true = y_true.view(-1).cpu().detach().numpy().astype(np.int)
21
    y_pre = y_pre.view(-1).cpu().detach().numpy() > threshold
22
    return f1_score(y_true, y_pre)
23
24
#打印时间
25
def print_time_cost(since):
26
    time_elapsed = time.time() - since
27
    return '{:.0f}m{:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60)
28
29
30
# 调整学习率
31
def adjust_learning_rate(optimizer, lr):
32
    for param_group in optimizer.param_groups:
33
        param_group['lr'] = lr
34
    return lr
35
36
#多标签使用类别权重
37
class WeightedMultilabel(nn.Module):
38
    def __init__(self, weights: torch.Tensor):
39
        super(WeightedMultilabel, self).__init__()
40
        self.cerition = nn.BCEWithLogitsLoss(reduction='none')
41
        self.weights = weights
42
43
    def forward(self, outputs, targets):
44
        loss = self.cerition(outputs, targets)
45
        return (loss * self.weights).mean()