Diff of /metrics.py [000000] .. [d90d15]

Switch to unified view

a b/metrics.py
1
import numpy as np
2
3
4
class Metric:
5
    def __init__(self):
6
        pass
7
8
    def __call__(self, outputs, target, loss):
9
        raise NotImplementedError
10
11
    def reset(self):
12
        raise NotImplementedError
13
14
    def value(self):
15
        raise NotImplementedError
16
17
    def name(self):
18
        raise NotImplementedError
19
20
21
class AccumulatedAccuracyMetric(Metric):
22
    """
23
    Works with classification model
24
    """
25
26
    def __init__(self):
27
        self.correct = 0
28
        self.total = 0
29
30
    def __call__(self, outputs, target, loss):
31
        pred = outputs[0].data.max(1, keepdim=True)[1]
32
        self.correct += pred.eq(target[0].data.view_as(pred)).cpu().sum()
33
        self.total += target[0].size(0)
34
        return self.value()
35
36
    def reset(self):
37
        self.correct = 0
38
        self.total = 0
39
40
    def value(self):
41
        return 100 * float(self.correct) / self.total
42
43
    def name(self):
44
        return 'Accuracy'
45
46
47
class AverageNonzeroTripletsMetric(Metric):
48
    '''
49
    Counts average number of nonzero triplets found in minibatches
50
    '''
51
52
    def __init__(self):
53
        self.values = []
54
55
    def __call__(self, outputs, target, loss):
56
        self.values.append(loss[1])
57
        return self.value()
58
59
    def reset(self):
60
        self.values = []
61
62
    def value(self):
63
        return np.mean(self.values)
64
65
    def name(self):
66
        return 'Average nonzero triplets'