[4fa73e]: / pytorch / utils / metrics.py

Download this file

111 lines (88 with data), 3.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
"""
Metrics for the framework: confusion-matrix based IoU, running-average meters and top-k classification accuracy.
"""
import numpy as np
class IOUMetric:
    """
    Class to calculate mean-iou using fast_hist method.

    A ``num_classes x num_classes`` confusion matrix is accumulated in
    ``self.hist``; ``hist[t, p]`` counts the elements whose ground-truth label
    is ``t`` and whose predicted label is ``p``. Elements whose ground-truth
    label lies outside ``[0, num_classes)`` are ignored.
    """

    def __init__(self, num_classes):
        """
        :param num_classes: number of valid labels (labels are 0..num_classes-1)
        """
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.hist = np.zeros((self.num_classes, self.num_classes))

    def _fast_hist(self, label_pred, label_true):
        """
        Build the confusion matrix of one prediction / ground-truth pair.

        :param label_pred: array-like of predicted labels
        :param label_true: array-like of ground-truth labels, same shape
        :return: integer array of shape (num_classes, num_classes)
        """
        label_pred = np.asarray(label_pred)
        label_true = np.asarray(label_true)
        # Only the ground truth decides which elements are counted.
        mask = (label_true >= 0) & (label_true < self.num_classes)
        # Encode each (true, pred) pair as one flat index. Both operands are
        # cast to int because np.bincount rejects floating point input.
        flat_index = (
            self.num_classes * label_true[mask].astype(int)
            + label_pred[mask].astype(int)
        )
        counts = np.bincount(flat_index, minlength=self.num_classes ** 2)
        return counts.reshape(self.num_classes, self.num_classes)

    def add_batch(self, predictions, gts):
        """
        Accumulate a batch of predictions into the confusion matrix.

        :param predictions: iterable of array-likes with predicted labels
        :param gts: iterable of array-likes with ground-truth labels
        """
        for lp, lt in zip(predictions, gts):
            self.hist += self._fast_hist(np.asarray(lp).ravel(),
                                         np.asarray(lt).ravel())

    def evaluate(self):
        """
        Compute the metrics from the accumulated confusion matrix.

        :return: tuple ``(acc, acc_cls, iu, mean_iu, fwavacc)``: overall
            accuracy, mean per-class accuracy, per-class IoU array, mean IoU
            and frequency weighted IoU. Classes with a zero denominator yield
            NaN in ``iu`` and are skipped by the ``nanmean`` based averages.
        """
        diag = np.diag(self.hist)
        true_totals = self.hist.sum(axis=1)
        pred_totals = self.hist.sum(axis=0)
        total = self.hist.sum()
        # 0/0 for absent classes is expected here; the resulting NaNs are
        # handled below, so silence the floating point warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            acc = diag.sum() / total
            acc_cls = diag / true_totals
            iu = diag / (true_totals + pred_totals - diag)
            freq = true_totals / total
        acc_cls = np.nanmean(acc_cls)
        mean_iu = np.nanmean(iu)
        present = freq > 0
        fwavacc = (freq[present] * iu[present]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc
class AverageMeter:
    """
    Class to be an average meter for any average metric like loss, accuracy, etc..

    ``value`` holds the most recent sample, ``sum`` the weighted sum of all
    samples, ``count`` the total weight and ``avg`` the running average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.value = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Add a new sample to the meter.

        :param val: the sample value
        :param n: weight of the sample (e.g. the batch size it was averaged over)
        """
        self.value = val
        self.sum += val * n
        self.count += n
        # A total weight of zero (e.g. update(x, n=0) on a fresh meter) would
        # divide by zero; keep the average at its reset value instead.
        self.avg = self.sum / self.count if self.count else 0

    @property
    def val(self):
        """Running average (same as ``avg``), not the last sample."""
        return self.avg
class AverageMeterList:
    """
    Class to be an average meter for any average metric List structure like mean_iou_per_class

    Every statistic is a list with one entry per class. ``value`` holds the
    most recent sample, ``sum`` the weighted sum, ``count`` the total weight
    and ``avg`` the running average of each class.
    """

    def __init__(self, num_cls):
        """
        :param num_cls: number of entries tracked by the meter
        """
        self.cls = num_cls
        self.reset()

    def reset(self):
        """Set every per-class statistic back to zero."""
        self.value = [0] * self.cls
        self.avg = [0] * self.cls
        self.sum = [0] * self.cls
        self.count = [0] * self.cls

    def update(self, val, n=1):
        """
        Add a new sample for every class.

        NaN entries are recorded in ``value`` but excluded from ``sum``,
        ``count`` and ``avg``, so one undefined sample for a class does not
        turn that class's running average into NaN.

        :param val: indexable with at least ``num_cls`` entries
        :param n: weight of the sample
        """
        for i in range(self.cls):
            item = val[i]
            self.value[i] = item
            if item != item:  # True only for NaN
                continue
            self.sum[i] += item * n
            self.count[i] += n
            # Leave avg untouched while the total weight is still zero.
            if self.count[i]:
                self.avg[i] = self.sum[i] / self.count[i]

    @property
    def val(self):
        """Per-class running averages (same list as ``avg``)."""
        return self.avg
def cls_accuracy(output, target, topk=(1,)):
    """
    Compute the top-k accuracy for each k in ``topk``.

    :param output: tensor of scores; the top-k is taken along dimension 1 and
        dimension 0 must match ``target`` (i.e. ``(batch, classes)``)
    :param target: tensor with one class index per sample
    :param topk: iterable of k values to evaluate
    :return: list with one entry per k, the fraction of samples whose target
        is among the k highest scoring classes
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch)
    # so that row j holds every sample's (j+1)-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` may be non-contiguous because it is
        # derived from the transposed `pred`, and view() raises in that case.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k / batch_size)
    return res