--- a +++ b/utils.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +''' +@time: 2019/9/12 15:16 + +@ author: javis +''' +import torch +import numpy as np +import time,os +from sklearn.metrics import f1_score +from torch import nn + + +def mkdirs(path): + if not os.path.exists(path): + os.makedirs(path) + +#计算F1score +def calc_f1(y_true, y_pre, threshold=0.5): + y_true = y_true.view(-1).cpu().detach().numpy().astype(np.int) + y_pre = y_pre.view(-1).cpu().detach().numpy() > threshold + return f1_score(y_true, y_pre) + +#打印时间 +def print_time_cost(since): + time_elapsed = time.time() - since + return '{:.0f}m{:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60) + + +# 调整学习率 +def adjust_learning_rate(optimizer, lr): + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr + +#多标签使用类别权重 +class WeightedMultilabel(nn.Module): + def __init__(self, weights: torch.Tensor): + super(WeightedMultilabel, self).__init__() + self.cerition = nn.BCEWithLogitsLoss(reduction='none') + self.weights = weights + + def forward(self, outputs, targets): + loss = self.cerition(outputs, targets) + return (loss * self.weights).mean()