"""Text-generation evaluation metrics (WER, ROUGE, BLEU-1..4) built on torchmetrics."""

from torchmetrics.text import BLEUScore
from torchmetrics.text import WordErrorRate
from torchmetrics.functional.text.rouge import rouge_score


def compute_metrics(preds, labels):
    """Compute WER, ROUGE and BLEU-1..4 for predictions against references.

    Args:
        preds: Predicted texts (presumably a list of strings, one per
            sample -- confirm against the caller).
        labels: Reference texts aligned one-to-one with ``preds``.

    Returns:
        A dict of plain Python scalars (extracted with ``.item()``) holding:
        ``'wer'``; every key returned by ``compute_rouge``; and
        ``'bleu-1'`` through ``'bleu-4'``. Note that the ROUGE entries are
        scaled by 100 (see ``compute_rouge``) while WER and BLEU are stored
        exactly as torchmetrics returns them.
    """
    result = {}

    result['wer'] = WordErrorRate()(preds, labels).item()

    for name, value in compute_rouge(preds, labels).items():
        result[name] = value.item()

    # BLEUScore takes a list of references per prediction, so wrap each
    # single reference in its own one-element list.
    references = [[label] for label in labels]
    for n in range(1, 5):
        result[f'bleu-{n}'] = BLEUScore(n_gram=n)(preds, references).item()

    return result


def compute_rouge(preds, labels):
    """Average per-sample ROUGE scores and scale them by 100.

    ``rouge_score`` is evaluated separately on each (prediction, reference)
    pair; the values for each key are summed and then averaged.

    Args:
        preds: Predicted texts.
        labels: Reference texts aligned one-to-one with ``preds``.

    Returns:
        A dict with the same keys as ``rouge_score`` returns, each mapped to
        the mean over all scored pairs multiplied by 100. Empty when there
        are no pairs to score.
    """
    totals = {}
    pair_count = 0
    for label, pred in zip(labels, preds):
        pair_count += 1
        for key, value in rouge_score(pred, label).items():
            totals[key] = totals.get(key, 0) + value

    # Divide by the number of pairs actually scored rather than len(labels):
    # zip() stops at the shorter input, so the two can differ, and counting
    # also works for iterables that have no len(). When nothing was scored
    # ``totals`` is empty and no division takes place.
    return {key: total / pair_count * 100 for key, total in totals.items()}