Diff of /utils.py [000000] .. [dcdaea]

Switch to side-by-side view

--- a
+++ b/utils.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+'''
+@time: 2019/9/12 15:16
+
+@ author: javis
+'''
+import torch
+import numpy as np
+import time,os
+from sklearn.metrics import f1_score
+from torch import nn
+
+
+def mkdirs(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+#计算F1score
+def calc_f1(y_true, y_pre, threshold=0.5):
+    y_true = y_true.view(-1).cpu().detach().numpy().astype(np.int)
+    y_pre = y_pre.view(-1).cpu().detach().numpy() > threshold
+    return f1_score(y_true, y_pre)
+
+#打印时间
+def print_time_cost(since):
+    time_elapsed = time.time() - since
+    return '{:.0f}m{:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60)
+
+
+# 调整学习率
+def adjust_learning_rate(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+    return lr
+
+#多标签使用类别权重
+class WeightedMultilabel(nn.Module):
+    def __init__(self, weights: torch.Tensor):
+        super(WeightedMultilabel, self).__init__()
+        self.cerition = nn.BCEWithLogitsLoss(reduction='none')
+        self.weights = weights
+
+    def forward(self, outputs, targets):
+        loss = self.cerition(outputs, targets)
+        return (loss * self.weights).mean()