Switch to side-by-side view

--- a
+++ b/development/summarizer-test-pipeline/summarizer_test/summarizertest.py
@@ -0,0 +1,107 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from typing import List, Union
+import json
+from tqdm import tqdm
+
+
+class ReportGenerator():
+    def __init__(self,
+                 models_names: List[str],
+                 val_contexts_path: str,
+                 report_path: str,
+                 max_lengths: list,
+                 min_lengths: list,
+                 top_k: list,
+                 penalty_l: list,
+                 no_repeat_ngram_size: list,
+                 num_return_sequences: list,
+                 ):
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.models_names = models_names
+        self.val_contexts_path = val_contexts_path
+        self.report_path = report_path
+        self.max_lengths = max_lengths
+        self.min_lengths = min_lengths
+        self.top_k = top_k
+        self.penalty_l = penalty_l
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.num_return_sequences = num_return_sequences
+
+    def __load_json_data(self, path: str):
+        with open(path, 'r') as json_file:
+            ds = json.load(json_file)
+        return ds
+
+    def __write_json_data(self, path: str, data: Union[list, dict]):
+        with open(path, 'w') as f:
+            json.dump(data, f)
+
+    def __init_tokenizer(self, model: str):
+        return AutoTokenizer.from_pretrained(model)
+
+    def __init_model(self, model: str):
+        return AutoModelForSeq2SeqLM.from_pretrained(model).to(self.device)
+
+    def __summerize(self, text: str, tokenizer, model, min_l, max_l, top_k, penalty_l, no_repeat_ngram_size, num_return_sequences):
+        inputs = tokenizer(
+            [text], 
+            padding="max_length",
+            truncation=True, 
+            max_length=512, 
+            return_tensors="pt",
+            )
+        input_ids = inputs.input_ids.to(self.device)
+        attention_mask = inputs.attention_mask.to(self.device)
+        output = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            min_length=min_l,
+            max_length=max_l,
+            num_beams=1,
+            length_penalty=penalty_l,
+            early_stopping=True,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            num_return_sequences=num_return_sequences,
+            do_sample=True,
+            top_k=top_k,
+            top_p=None,
+            output_scores=True,
+            return_dict_in_generate=True,
+        )
+
+        return [tokenizer.decode(ans, skip_special_tokens=True) for ans in output[0]]
+
+    def get_report(self):
+        print(f'{self.device} is available')
+        self.__val_data = self.__load_json_data(self.val_contexts_path)
+        self.logs = []
+        for ckpt in self.models_names:
+            print(f'start "{ckpt}" model')
+
+            self.tokenizer = self.__init_tokenizer(ckpt)
+            self.model = self.__init_model(ckpt)
+            log = []
+            for context in tqdm(self.__val_data):
+                for maxl in self.max_lengths:
+                    for minl in self.min_lengths:
+                        for tk in self.top_k:
+                            for pl in self.penalty_l:
+                                for nrns in self.no_repeat_ngram_size:
+                                    for nrs in self.num_return_sequences:
+                                        summerized = self.__summerize(
+                                            context['context'], self.tokenizer, self.model, minl, maxl, tk, pl, nrns, nrs)
+                                        log.append({
+                                            'summary': summerized,
+                                            'context': context['context'],
+                                            'max_length': maxl,
+                                            'min_length': minl,
+                                            'top_k': tk,
+                                            'penalty_length': pl,
+                                            'bo_repeat_ngram_size': nrns,
+                                            'num_return_sequences': nrs,
+                                        })
+            self.logs.append({'model': ckpt, 'log': log})
+
+        self.__write_json_data(self.report_path, self.logs)
+        print(f"report saved into {self.report_path}")