Diff of /metrics.py [000000] .. [d90d15]

Switch to side-by-side view

--- a
+++ b/metrics.py
@@ -0,0 +1,66 @@
+import numpy as np
+
+
+class Metric:
+    def __init__(self):
+        pass
+
+    def __call__(self, outputs, target, loss):
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def value(self):
+        raise NotImplementedError
+
+    def name(self):
+        raise NotImplementedError
+
+
+class AccumulatedAccuracyMetric(Metric):
+    """
+    Works with classification model
+    """
+
+    def __init__(self):
+        self.correct = 0
+        self.total = 0
+
+    def __call__(self, outputs, target, loss):
+        pred = outputs[0].data.max(1, keepdim=True)[1]
+        self.correct += pred.eq(target[0].data.view_as(pred)).cpu().sum()
+        self.total += target[0].size(0)
+        return self.value()
+
+    def reset(self):
+        self.correct = 0
+        self.total = 0
+
+    def value(self):
+        return 100 * float(self.correct) / self.total
+
+    def name(self):
+        return 'Accuracy'
+
+
+class AverageNonzeroTripletsMetric(Metric):
+    '''
+    Counts average number of nonzero triplets found in minibatches
+    '''
+
+    def __init__(self):
+        self.values = []
+
+    def __call__(self, outputs, target, loss):
+        self.values.append(loss[1])
+        return self.value()
+
+    def reset(self):
+        self.values = []
+
+    def value(self):
+        return np.mean(self.values)
+
+    def name(self):
+        return 'Average nonzero triplets'