# green3.py
import os
import re
import sys
import time
import warnings

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from .utils import (
    gather_processes,
    make_prompt,
    clean_responses,
    compute_largest_cluster,
    flatten_values_lists_of_list_dicts_to_dict,
)


def truncate_to_max_len(sentences, max_len):
    return [" ".join(sentence.split()[:max_len]) for sentence in sentences]


def get_rank():
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def tqdm_on_main(*args, **kwargs):
    # wrap the iterable in tqdm only on the main process; other ranks get the bare iterable
    if is_main_process():
        print("==== Beginning Inference ====")
        return tqdm(*args, **kwargs)
    else:
        return kwargs.get("iterable", None)


class Inferer:
    def __init__(
        self,
        dataset=None,
        model=None,
        tokenizer=None,
        model_name="",
        output_dir=".",
        num_examples=None,
        batch_size=10,
        max_length=2048,
    ):
        self.dataset = Dataset.from_dict(
            {"reference": dataset[0], "prediction": dataset[1]}
        )
        self.process_data()

        self.model = model
        self.model_name = model_name.split("/")[-1]
        self.tokenizer = tokenizer
        self.num_examples = num_examples

        self.output_dir = output_dir
        self.batch_size = batch_size
        self.max_length = max_length

        self.prompts = None
        self.completions = None
        self.green_scores = None
        self.error_counts = None

        self.categories = [
            "Clinically Significant Errors",
            "Clinically Insignificant Errors",
            "Matched Findings",
        ]

        self.sub_categories = [
            "(a) False report of a finding in the candidate",
            "(b) Missing a finding present in the reference",
            "(c) Misidentification of a finding's anatomic location/position",
            "(d) Misassessment of the severity of a finding",
            "(e) Mentioning a comparison that isn't in the reference",
            "(f) Omitting a comparison detailing a change from a prior study",
        ]

    def process_data(self):
        print("Processing data... making prompts")

        def prompting(examples):
            return {
                "prompt": [
                    make_prompt(r, p)
                    for r, p in zip(examples["reference"], examples["prediction"])
                ]
            }

        self.dataset = self.dataset.map(prompting, batched=True)
        print("Done.")

    @torch.inference_mode()
    def infer(self):
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            dataset_dist = split_dataset_by_node(
                self.dataset,
                rank=get_rank(),
                world_size=int(os.environ["WORLD_SIZE"]),
            )
            print("Distributed dataset created on rank: ", int(os.environ["RANK"]))
        else:
            dataset_dist = self.dataset

        local_completions = []
        local_references = []

        for batch in tqdm_on_main(
            iterable=dataset_dist.iter(batch_size=self.batch_size),
            total=(len(dataset_dist) + self.batch_size - 1) // self.batch_size,
        ):
            local_references.extend(batch["prompt"])
            local_completions.extend(self.get_response(batch))

        # gather results across ranks in the multi-GPU setting; otherwise keep the local lists
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self.completions, self.prompts = gather_processes(
                local_completions, local_references
            )
        else:
            self.completions = local_completions
            self.prompts = local_references

        if is_main_process():
            print("==== End Inference ====")

        if len(self.completions) != len(self.prompts):
            print("Warning: the numbers of prompts and completions do not match!")

        self.process_results()

    def tokenize_batch_as_chat(self, batch):
        # render each conversation with the chat template, then tokenize the whole batch
        texts = [
            self.tokenizer.apply_chat_template(
                conv, tokenize=False, add_generation_prompt=True
            )
            for conv in batch["conv"]
        ]

        batch = self.tokenizer.batch_encode_plus(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(int(os.environ.get("LOCAL_RANK", 0)))

        return batch

    def get_response(self, batch):
        # format the batch as single-turn conversations
        assert "prompt" in batch.keys(), "prompt is not in batch keys"

        batch["conv"] = [
            [
                {"from": "human", "value": prompt},
            ]
            for prompt in batch["prompt"]
        ]
        batch = self.tokenize_batch_as_chat(batch)

        outputs = self.model.generate(
            **batch,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            generation_config=GenerationConfig(
                max_new_tokens=self.max_length,
                do_sample=False,
            ),
        )

        # decode the generated responses
        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # reformat the responses
        response_list = []
        if isinstance(responses, list):
            for response in responses:
                response_list.append(clean_responses(response))
        else:
            response_list.append(clean_responses(responses))

        return response_list

    def process_results(self):
        self.green_scores = [
            self.compute_green(response) for response in self.completions
        ]
        self.error_counts = pd.DataFrame(
            [self.compute_error_count(response) for response in self.completions],
            columns=self.sub_categories + ["Matched Findings"],
        )

        results_df = pd.DataFrame(
            {
                "reference": self.dataset["reference"],
                "predictions": self.dataset["prediction"],
                "evaluation": self.completions,
                "green": self.green_scores,
                **self.error_counts,  # one column per error sub-category count
            }
        )

        os.makedirs(self.output_dir, exist_ok=True)

        path = os.path.join(self.output_dir, f"results_{self.model_name}.csv")
        print("Saving per-example results to ", path)
        results_df.to_csv(path, index=False)

        summ = self.compute_summary()

        # save the summary as a plain-text file
        path = os.path.join(self.output_dir, f"resultsSummary_{self.model_name}.txt")
        print("Saving summary to ", path)
        with open(path, "w") as file:
            file.write(summ)

        return results_df

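    # Minimal sketch of reading back the per-example CSV written above (columns: reference,
    # predictions, evaluation, green, one count per sub-category (a)-(f), Matched Findings);
    # output_dir and model_name below stand in for the values used at save time:
    #
    #   df = pd.read_csv(os.path.join(output_dir, f"results_{model_name}.csv"))
    #   print(df["green"].mean())
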
    def compute_error_count(self, response):
        # per-type counts of clinically significant errors
        _, sig_errors = self.parse_error_counts(response, self.categories[0])
        # for matched findings we only need the total count
        matched_findings, _ = self.parse_error_counts(response, self.categories[2])
        return sig_errors + [matched_findings]

    def compute_green(self, response):
        # clinically significant errors: we need the per-type counts
        sig_present, sig_errors = self.parse_error_counts(response, self.categories[0])
        # matched findings: we only need the total count
        matched_findings, _ = self.parse_error_counts(response, self.categories[2])

        # Note: the prior-study sub-categories (e) and (f) are no longer zeroed out
        # (previously: sig_errors[-2:] = 0, 0).

        if matched_findings == 0:
            return 0

        if (
            sig_present is None or matched_findings is None
        ):  # when the response does not include the "Clinically Significant Errors" section
            return None

        return matched_findings / (matched_findings + sum(sig_errors))

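    # Worked example for compute_green (illustrative numbers only): with
    # matched_findings = 3 and sig_errors = [1, 0, 0, 1, 0, 0], the score is
    # 3 / (3 + 2) = 0.6; zero matched findings always yields a score of 0.
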
    def parse_error_counts(self, text, category, for_reward=False):
        if category not in self.categories:
            raise ValueError(
                f"Category {category} is not a valid category. Please choose from {self.categories}."
            )

        # Capture the text of the category block, stopping at the next blank line or end of text
        pattern = rf"\[{category}\]:\s*(.*?)(?:\n\s*\n|\Z)"
        category_text = re.search(pattern, text, re.DOTALL)

        # Initialize the counts
        sum_counts = 0
        sub_counts = [0 for _ in range(6)]

        # If the category is not found, return 0
        if not category_text:
            if for_reward:
                # we need to know whether the category is missing, otherwise we overestimate the reward
                return None, None
            return sum_counts, sub_counts
        # If the category is found but empty ("No ..."), return 0
        if category_text.group(1).startswith("No"):
            return sum_counts, sub_counts

        if category == "Matched Findings":
            counts = re.findall(r"^\b\d+\b(?=\.)", category_text.group(1))
            if len(counts) > 0:
                sum_counts = int(counts[0])
            return sum_counts, sub_counts
        # fine-grained sub-categories exist for "Clinically Significant Errors" and "Clinically Insignificant Errors"
        else:
            # Split each string at the first space and keep only the first part, e.g. "(a) "
            sub_categories = [s.split(" ", 1)[0] + " " for s in self.sub_categories]
            # Find all sub_categories in the matched text
            matches = sorted(re.findall(r"\([a-f]\) .*", category_text.group(1)))

            # this is for the gpt-4 template, which numbers the sub-categories instead of lettering them
            if len(matches) == 0:
                matches = sorted(re.findall(r"\([1-6]\) .*", category_text.group(1)))
                sub_categories = [
                    f"({i})" + " " for i in range(1, len(self.sub_categories) + 1)
                ]

            for position, sub_category in enumerate(sub_categories):
                # loop over all matches, because the sub_categories are not always in the same order
                for match in range(len(matches)):
                    if matches[match].startswith(sub_category):
                        # If the sub_category is found, insert the count into sub_counts at the ordered position
                        count = re.findall(r"(?<=: )\b\d+\b(?=\.)", matches[match])
                        if len(count) > 0:
                            # take the first number after the colon
                            sub_counts[position] = int(count[0])
            return sum(sub_counts), sub_counts

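    # Illustrative (assumed) fragment of the kind of model output that parse_error_counts
    # expects; the exact GREEN response template may differ:
    #
    #   [Clinically Significant Errors]:
    #   (a) False report of a finding in the candidate: 1. Pleural effusion reported but absent in the reference.
    #   (b) Missing a finding present in the reference: 0.
    #
    #   [Matched Findings]:
    #   2. Bibasilar atelectasis; low lung volumes.
    #
    # For "Clinically Significant Errors" this parses to (1, [1, 0, 0, 0, 0, 0]);
    # for "Matched Findings" it parses to (2, [0, 0, 0, 0, 0, 0]).
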
    def parse_error_sentences(self, response, category):
        """
        Parses error sentences from a given response based on the specified category. Extracts the sentences associated with each sub-category and returns them as a dict.

        Args:
            response (str): The input text containing error information.
            category (str): The category to parse within the text.

        Returns:
            dict: A dictionary where keys are sub-categories and values are lists of sentences associated with those sub-categories. If the category is "Matched Findings", returns a list of sentences directly.
        """
        if category not in self.categories:
            raise ValueError(
                f"Category {category} is not a valid category. Please choose from {self.categories}."
            )
        pattern = rf"\[{category}\]:\s*(.*?)(?:\n\s*\n|\Z)"
        category_text = re.search(pattern, response, re.DOTALL)

        sub_category_dict_sentences = {
            sub_category: [] for sub_category in self.sub_categories
        }

        if not category_text:
            return sub_category_dict_sentences
        if category_text.group(1).startswith("No"):
            return sub_category_dict_sentences

        if category == "Matched Findings":
            return (
                category_text.group(1).rsplit(":", 1)[-1].rsplit(".", 1)[-1].split(";")
            )

        matches = sorted(re.findall(r"\([a-f]\) .*", category_text.group(1)))

        if len(matches) == 0:
            # gpt-4 template: sub-categories are numbered instead of lettered;
            # note that this switches self.sub_categories to the numeric labels for later lookups
            matches = sorted(re.findall(r"\([1-6]\) .*", category_text.group(1)))
            self.sub_categories = [
                f"({i})" + " " for i in range(1, len(self.sub_categories) + 1)
            ]

        for position, sub_category in enumerate(self.sub_categories):
            # loop over all matches, because the sub_categories are not always in the same order
            for match in range(len(matches)):
                if matches[match].startswith(sub_category):
                    # If the sub_category is found, add its sentences to the dictionary
                    sentences_list = (
                        matches[match].rsplit(":", 1)[-1].split(".", 1)[-1].split(";")
                    )
                    sub_category_dict_sentences[self.sub_categories[position]] = (
                        sentences_list
                    )

        return sub_category_dict_sentences

    def compute_sentences(self, response):
        # for now we only look at the significant clinical errors, which is the first category
        return self.parse_error_sentences(response, self.categories[0])

    def get_representative_sentences(self, responses):
        list_sentences = []
        for response in responses:
            list_sentences.append(self.compute_sentences(response))

        dict_sentences = flatten_values_lists_of_list_dicts_to_dict(list_sentences)

        result_sentences_dict = {}
        for sub_category in self.sub_categories:
            sentences = dict_sentences[sub_category]
            sentences = [s for s in sentences if s.strip() != ""]
            _, sentences_of_largest_cluster = compute_largest_cluster(sentences)
            result_sentences_dict[sub_category] = sentences_of_largest_cluster

        return result_sentences_dict

    def compute_accuracy(self, responses):
        """
        Computes the accuracy for each sub-category based on the significant clinical error counts.

        Args:
            responses (list): Generated responses to evaluate.

        Returns:
            dict: accuracy per sub-category, i.e. the fraction of responses with zero errors of that type.
        """
        counts = []
        for response in responses:
            _, sig_errors = self.parse_error_counts(response, self.categories[0])
            counts.append(sig_errors)

        counts = np.array(counts)

        dict_acc = {}
        for i in range(len(self.sub_categories)):
            error_counts = counts[:, i]
            # accuracy for a sub-category: share of responses with a zero error count
            accuracy = np.mean(error_counts == 0)
            dict_acc[self.sub_categories[i]] = accuracy

        return dict_acc

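    # Illustrative example for compute_accuracy (assumed counts): for three responses with
    # sig_errors [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0] and [2, 0, 0, 0, 0, 0],
    # sub-category (a) gets accuracy 1/3 while (b)-(f) each get accuracy 1.0.
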
    def compute_summary(self):
        """
        Builds the GREEN summary string.

        Returns:
            str: summary containing the mean and standard deviation of the GREEN score and the per-sub-category accuracies.
        """
        print("Computing summary ...")
        # representative_sentences = self.get_representative_sentences(self.completions)
        accuracies = self.compute_accuracy(self.completions)

        #summary = f"\n-------------{self.model_name}----------------\n [Summary]: Green average {np.mean(self.green_scores)} and standard variation {np.std(self.green_scores)} \n [Clinically Significant Errors Analyses]: <accuracy>. <representative error>\n\n (a) False report of a finding in the candidate: {accuracies[self.sub_categories[0]]}. \n {representative_sentences[self.sub_categories[0]]} \n\n (b) Missing a finding present in the reference: {accuracies[self.sub_categories[1]]}. \n {representative_sentences[self.sub_categories[1]]} \n\n (c) Misidentification of a finding's anatomic location/position: {accuracies[self.sub_categories[2]]}. \n {representative_sentences[self.sub_categories[2]]} \n\n (d) Misassessment of the severity of a finding: {accuracies[self.sub_categories[3]]}. \n {representative_sentences[self.sub_categories[3]]} \n\n (e) Mentioning a comparison that isn't in the reference: {accuracies[self.sub_categories[4]]}. \n {representative_sentences[self.sub_categories[4]]} \n\n (f) Omitting a comparison detailing a change from a prior study: {accuracies[self.sub_categories[5]]}. {representative_sentences[self.sub_categories[5]]}.\n----------------------------------\n"
        summary = f"\n-------------{self.model_name}----------------\n [Summary]: Green average {np.mean(self.green_scores)} and standard variation {np.std(self.green_scores)} \n [Clinically Significant Errors Analyses]: <accuracy>. <representative error>\n\n (a) False report of a finding in the candidate: {accuracies[self.sub_categories[0]]}. \n\n\n (b) Missing a finding present in the reference: {accuracies[self.sub_categories[1]]}. \n\n\n (c) Misidentification of a finding's anatomic location/position: {accuracies[self.sub_categories[2]]}. \n\n\n (d) Misassessment of the severity of a finding: {accuracies[self.sub_categories[3]]}. \n\n\n (e) Mentioning a comparison that isn't in the reference: {accuracies[self.sub_categories[4]]}. \n\n\n (f) Omitting a comparison detailing a change from a prior study: {accuracies[self.sub_categories[5]]}.\n----------------------------------\n"

        print(summary)
        return summary


def compute(model_name, refs, hyps, output_dir="."):
    # this warning appears despite padding_side="left" and correct padding
    warnings.filterwarnings("ignore", message="A decoder-only architecture is being used*")
    from sklearn.exceptions import ConvergenceWarning
    warnings.filterwarnings("ignore", category=ConvergenceWarning, message="Number of distinct clusters.*")  # test examples are copied
    warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.tokenization_utils_base")

    chat_template = "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n'  + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
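
    # With this template, a single-turn "human" message renders roughly as follows
    # (assuming the tokenizer's eos_token is "</s>"; exact whitespace depends on the template):
    #
    #   <|user|>
    #   {prompt}</s>
    #   <|assistant|>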

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
            )  # 'nccl' is recommended for GPUs
            torch.cuda.set_device(dist.get_rank())
            if dist.get_rank() == 0:
                print("Distributed inference with", torch.cuda.device_count(), "GPUs")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=False if "Phi" in model_name else True,
        device_map={"": "cuda:{}".format(torch.cuda.current_device())},
        torch_dtype=torch.float16,
    )

    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        add_eos_token=True,
        use_fast=True,
        trust_remote_code=True,
        padding_side="left",
    )

    tokenizer.chat_template = chat_template
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.clean_up_tokenization_spaces = True

    inferer = Inferer(
        dataset=[refs, hyps],
        model=model,
        model_name=model_name,
        tokenizer=tokenizer,
        output_dir=output_dir,
        batch_size=12,  # 16
    )

    t = time.time()

    inferer.infer()

    t = time.time() - t
    print("Seconds per example: ", t / len(refs))

    if not is_main_process():
        # non-main ranks clean up the distributed process group and exit
        print(f"Rank {dist.get_rank()} exiting.")
        dist.destroy_process_group()
        sys.exit()



if __name__ == "__main__":
    refs = [
        "Interstitial opacities without changes.",
        "Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
        "Bibasilar atelectasis. Otherwise, no acute intrathoracic process.",
        "Lung volumes are low, causing bronchovascular crowding. The cardiomediastinal silhouette is unremarkable. No focal consolidation, pleural effusion, or pneumothorax detected. Within the limitations of chest radiography, osseous structures are unremarkable.",
        "Interval resolution of previously seen mild pulmonary edema with trace bilateral pleural effusions.",
        "Lung volumes are low, causing bronchovascular crowding. The cardiomediastinal silhouette is unremarkable. No focal consolidation, pleural effusion, or pneumothorax detected. Within the limitations of chest radiography, osseous structures are unremarkable.",
        "Bilateral pleural effusions, large on the right and small on the left. No definite focal consolidation identified, although evaluation is limited secondary to these effusions.",
        "1. Mild left basal atelectasis. Otherwise unremarkable. 2. No definite displaced rib fracture though if there is continued concern dedicated rib series may be performed to further assess.",
        "Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
    ]
    hyps = [
        "Interstitial opacities at bases without changes.",
        "Interval resolution of previously seen mild pulmonary edema with trace bilateral pleural effusions.",
        "Bibasilar atelectasis. Otherwise, no acute intrathoracic process.",
        "Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
        "Endotracheal and nasogastric tubes have been removed. Changes of median sternotomy, with continued leftward displacement of the fourth inferiomost sternal wire. There is continued moderate-to-severe enlargement of the cardiac silhouette. Pulmonary aeration is slightly improved, with residual left lower lobe atelectasis. Stable central venous congestion and interstitial pulmonary edema. Small bilateral pleural effusions are unchanged.",
        "Endotracheal and nasogastric tubes have been removed. Changes of median sternotomy, with continued leftward displacement of the fourth inferiomost sternal wire. There is continued moderate-to-severe enlargement of the cardiac silhouette. Pulmonary aeration is slightly improved, with residual left lower lobe atelectasis. Stable central venous congestion and interstitial pulmonary edema. Small bilateral pleural effusions are unchanged.",
        "In comparison with the study of ___, the increased opacification at the right base has essentially cleared with better inspiration. Cardiac silhouette remains at the upper limits of normal in size and there is again tortuosity of the aorta without vascular congestion or pleural effusion. Biapical changes, especially on the right, are stable.",
        "1. Mild left basal atelectasis. Otherwise unremarkable.",
        "1. Mild left basal atelectasis. Otherwise unremarkable. 2. No definite displaced rib fracture though if there is continued concern dedicated rib series may be performed to further assess.",
    ]

    model_name = "StanfordAIMI/GREEN-radllama2-7b"

    compute(model_name, refs, hyps, output_dir=".")