|
a |
|
b/metrics.py |
|
|
1 |
from torchmetrics.text import BLEUScore |
|
|
2 |
from torchmetrics.text import WordErrorRate |
|
|
3 |
from torchmetrics.functional.text.rouge import rouge_score |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def compute_metrics(preds, labels):
    """Compute WER, ROUGE, and BLEU-1..4 for predictions against references.

    Args:
        preds: list of predicted strings.
        labels: list of reference strings, one reference per prediction.

    Returns:
        dict mapping metric name ('wer', the ROUGE keys produced by
        compute_rouge, and 'bleu-1'..'bleu-4') to a Python float.
    """
    result = {}

    wer = WordErrorRate()
    result['wer'] = wer(preds, labels).item()

    # compute_rouge returns tensor values scaled to 0-100; unwrap to floats.
    rouge_result = compute_rouge(preds, labels)
    for key, value in rouge_result.items():
        result[key] = value.item()

    # BLEUScore expects one *list* of references per prediction; keep the
    # original `labels` parameter untouched instead of rebinding it.
    bleu_refs = [[label] for label in labels]
    for n in range(1, 5):
        bleu = BLEUScore(n_gram=n)
        result[f'bleu-{n}'] = bleu(preds, bleu_refs).item()

    return result
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def compute_rouge(preds, labels):
    """Average per-pair ROUGE scores over the dataset, scaled to 0-100.

    Args:
        preds: list of predicted strings.
        labels: list of reference strings, same length as preds.

    Returns:
        dict mapping each key returned by rouge_score (e.g. 'rouge1_fmeasure')
        to its mean over all (pred, label) pairs multiplied by 100. Values
        keep the tensor type produced by rouge_score. Empty input yields {}.
    """
    # Guard: the original divided by len(labels) unconditionally, which
    # raised ZeroDivisionError on an empty dataset.
    if not labels:
        return {}

    totals = {}
    for decoded_label, decoded_pred in zip(labels, preds):
        pair_scores = rouge_score(decoded_pred, decoded_label)
        for key, value in pair_scores.items():
            totals[key] = totals.get(key, 0) + value

    return {key: total / len(labels) * 100 for key, total in totals.items()}