Diff of /metrics.py [000000] .. [66af30]

Switch to unified view

a b/metrics.py
1
from torchmetrics.text import BLEUScore
2
from torchmetrics.text import WordErrorRate
3
from torchmetrics.functional.text.rouge import rouge_score
4
5
6
def compute_metrics(preds, labels):
7
    result={}
8
    wer = WordErrorRate()
9
    result['wer']=wer(preds,labels).item()
10
    rouge_result=compute_rouge(preds,labels)
11
    for k,v in rouge_result.items():
12
        result[k]=v.item()
13
    labels=[[label] for i,label in enumerate(labels)]
14
    for i in range(1,5):
15
        bleu=BLEUScore(n_gram=i)
16
        result[f'bleu-{i}']=bleu(preds,labels).item()
17
    return result
18
19
20
def compute_rouge(preds, labels):
21
22
    metrics={}
23
    for decoded_label, decoded_pred in zip(labels, preds):
24
        metric=rouge_score(decoded_pred,decoded_label)
25
        for key in metric.keys():
26
            metrics[key]=metrics.get(key, 0) + metric[key]
27
    for key in metrics.keys():
28
        metrics[key]=metrics[key]/len(labels)*100
29
30
    return metrics