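"""
Evaluation metrics for the framework: mean IoU computed from an accumulated
confusion matrix, running averages (scalar and per-class), and top-k
classification accuracy.
"""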
import numpy as np
class IOUMetric:
    """
    Accumulate a confusion matrix over batches and compute segmentation metrics.

    ``hist`` is a ``num_classes x num_classes`` matrix whose rows are indexed by
    ground-truth class and whose columns are indexed by predicted class. It is
    built with the bincount-based "fast_hist" method.
    """

    def __init__(self, num_classes):
        """
        :param num_classes: number of valid classes. Ground-truth labels outside
            ``[0, num_classes)`` (e.g. an ignore label) are skipped.
        """
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.hist = np.zeros((self.num_classes, self.num_classes))

    def _fast_hist(self, label_pred, label_true):
        """
        Build the confusion matrix for one flattened prediction/label pair.

        :param label_pred: 1-D array of predicted class indices, the same length
            as ``label_true``.
        :param label_true: 1-D array of ground-truth labels. Entries outside
            ``[0, num_classes)`` are ignored together with their predictions.
        :return: integer array of shape ``(num_classes, num_classes)``.
        :raises ValueError: if a prediction at a counted position is not a class
            index in ``[0, num_classes)``.
        """
        n = self.num_classes
        mask = (label_true >= 0) & (label_true < n)
        true = label_true[mask].astype(int)
        # Cast the predictions too; float inputs would otherwise make the
        # combined index float and np.bincount would reject it.
        pred = label_pred[mask].astype(int)
        # Cells are addressed as n * true + pred in a flat n*n histogram. An
        # out-of-range prediction would silently land in a different cell.
        if pred.size and (pred.min() < 0 or pred.max() >= n):
            raise ValueError(
                "predictions must be class indices in [0, %d)" % n)
        return np.bincount(n * true + pred, minlength=n ** 2).reshape(n, n)

    def add_batch(self, predictions, gts):
        """
        Accumulate a batch of predictions into the confusion matrix.

        :param predictions: iterable of predicted label arrays.
        :param gts: iterable of ground-truth label arrays, paired element-wise
            with ``predictions``. Each pair is flattened, so both arrays of a pair
            must contain the same number of elements.
        """
        for lp, lt in zip(predictions, gts):
            self.hist += self._fast_hist(lp.flatten(), lt.flatten())

    def evaluate(self):
        """
        Compute metrics from the accumulated confusion matrix.

        A class that never appears in the ground truth or the predictions yields
        NaN in the per-class arrays. The ``nanmean`` averages skip those NaNs.

        :return: tuple ``(acc, acc_cls, iu, mean_iu, fwavacc)``: overall pixel
            accuracy, mean per-class accuracy, per-class IoU array, mean IoU,
            and frequency-weighted IoU.
        """
        hist = self.hist
        tp = np.diag(hist)
        gt_count = hist.sum(axis=1)
        pred_count = hist.sum(axis=0)
        total = hist.sum()
        # 0/0 for absent classes is expected and represented as NaN. Silence
        # the RuntimeWarnings numpy would otherwise emit on every call.
        with np.errstate(divide="ignore", invalid="ignore"):
            acc = tp.sum() / total
            acc_cls = np.nanmean(tp / gt_count)
            iu = tp / (gt_count + pred_count - tp)
            mean_iu = np.nanmean(iu)
            freq = gt_count / total
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc
class AverageMeter:
    """
    Running weighted average of a scalar quantity such as a loss or an accuracy.

    ``value`` holds the most recent sample. ``sum`` and ``count`` hold the
    weighted totals, and ``avg`` (also exposed as ``val``) is their ratio.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every sample recorded so far."""
        self.value = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` with weight ``n`` and refresh the running average.

        :param val: the new sample; it is added to ``sum`` as ``val * n``.
        :param n: weight (sample count) attributed to ``val``.
        """
        self.value = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    @property
    def val(self):
        """The running average (the most recent sample is in ``value``)."""
        return self.avg
class AverageMeterList:
    """
    Element-wise running averages for a fixed-length list of metrics,
    for example one IoU value per class.
    """

    def __init__(self, num_cls):
        self.cls = num_cls
        self.reset()

    def reset(self):
        """Zero the per-entry last value, average, sum and count."""
        size = self.cls
        self.value = [0] * size
        self.avg = [0] * size
        self.sum = [0] * size
        self.count = [0] * size

    def update(self, val, n=1):
        """
        Fold one sample per entry into the running averages.

        :param val: indexable holding at least ``cls`` samples. Only the first
            ``cls`` entries are read.
        :param n: weight applied to every sample in ``val``.
        """
        for idx in range(self.cls):
            sample = val[idx]
            self.value[idx] = sample
            self.sum[idx] += sample * n
            self.count[idx] += n
            self.avg[idx] = self.sum[idx] / self.count[idx]

    @property
    def val(self):
        """The list of running averages, one per entry."""
        return self.avg
def cls_accuracy(output, target, topk=(1,)):
    """
    Compute top-k classification accuracy for each k in ``topk``.

    :param output: score tensor whose classes are ranked along dimension 1
        (presumably shape ``(batch, num_classes)``; confirm with callers).
    :param target: tensor of ground-truth class indices, one per sample.
    :param topk: iterable of k values to evaluate.
    :return: list of 0-dim float tensors, in the same order as ``topk``. Each
        holds the fraction (in ``[0, 1]``, not a percentage) of samples whose
        target is among their top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # Indices of the maxk highest scores per sample, sorted descending. After
    # the transpose, row j holds every sample's j-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` keeps the transposed (non-contiguous) layout of `pred`, so
        # `.view(-1)` raises for k > 1; `.reshape` copies when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k / batch_size)
    return res