# minigpt4/datasets/datasets/mimic_dataset.py
import os
import json
import re
from PIL import Image
import webdataset as wds
import random
from torch.utils.data import Dataset
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset

class MIMICDataset(Dataset):
    """MIMIC-CXR image/report dataset (train split).

    Loads a JSON annotation file of the form
    ``{"train": [{"id": ..., "image_path": [...], "report": ...}, ...]}``
    and yields dicts with the processed image, the cleaned report text, and
    the sample id.
    """

    def __init__(self, vis_processor=None, text_processor=None, image_root=None, ann_path=None):
        """
        Args:
            vis_processor: callable applied to the loaded PIL image.
            text_processor: kept for interface parity with sibling datasets;
                not used in this class.
            image_root: directory that annotation image paths are relative to.
            ann_path: path to the JSON annotation file.
        """
        self.image_root = image_root
        self.ann_path = ann_path

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # load annotation file
        with open(ann_path, 'r') as f:
            self.annotations = json.load(f)
        self.train_data = self.annotations['train']

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        data_sample = self.train_data[index]
        image_path = data_sample['image_path']

        # load image; 'image_path' is a list of views — only the first is used
        image_id = data_sample['id']
        image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB')
        image = self.vis_processor(image)

        # load and normalize the report text
        caption = data_sample['report']
        caption = self.clean_reports(caption)

        return {"image": image,
                "text_input": caption,
                "image_id": image_id}

    def clean_reports(self, report):
        """Normalize a radiology report (R2Gen-style cleaning).

        Collapses repeated underscores/spaces/dots, strips enumerated list
        numbering ("1. ", ". 2. ", ...), lowercases, removes punctuation from
        each sentence, and re-joins sentences with ' . ', ending in ' .'.
        """
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        # BUG FIX: the original filtered on `sent_cleaner(sent) != []`, which is
        # always True (a str never equals a list), so empty cleaned sentences
        # leaked into the joined output; filter on string truthiness instead.
        tokens = [cleaned for cleaned in (sent_cleaner(sent) for sent in report_cleaner(report)) if cleaned]
        report = ' . '.join(tokens) + ' .'
        return report
class MIMICGenerateThenRefineDataset(Dataset):
    """MIMIC-CXR dataset for generate-then-refine training.

    In addition to the (image, report) pair, each item carries
    ``retrieval_size`` cleaned reference reports (the sample's own plus ones
    drawn from other training samples) and ``retrieval_size`` randomly
    sampled unlabeled reports.
    """

    def __init__(self, vis_processor=None, text_processor=None, image_root=None,
                 ann_path=None, unlabeled_ann_path=None, retrieval_size=3,
                 unlabeled_pool_size=3000):
        """
        Args:
            vis_processor: callable applied to the loaded PIL image.
            text_processor: kept for interface parity; not used in this class.
            image_root: directory that annotation image paths are relative to.
            ann_path: path to the JSON annotation file with a 'train' split.
            unlabeled_ann_path: text file with one unlabeled report per line.
            retrieval_size: number of reference / unlabeled reports per item.
            unlabeled_pool_size: size of the random subset retained from the
                unlabeled file (was a hard-coded 3000; kept as the default).
        """
        self.image_root = image_root
        self.ann_path = ann_path
        self.retrieval_size = retrieval_size

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        # load annotation file
        with open(ann_path, 'r') as f:
            self.annotations = json.load(f)
        self.train_data = self.annotations['train']

        # load unlabeled reports, one per line.
        # BUG FIX: the original iterated `f.readlines` (the bound method,
        # missing call parentheses), which raises TypeError at construction.
        with open(unlabeled_ann_path, 'r') as f:
            self.unlabeled_data_list = [line.strip('\n') for line in f]

        # keep a random subset; guard against pools smaller than the target
        # size (random.sample raises ValueError when k > population).
        keep = min(unlabeled_pool_size, len(self.unlabeled_data_list))
        self.unlabeled_data_list = random.sample(self.unlabeled_data_list, keep)

        print(f"There are total {len(self.unlabeled_data_list)} unlabeled reports.")

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        data = self.train_data[index]
        # draw (retrieval_size - 1) other training samples for extra refs
        data_samples = random.sample(self.train_data, self.retrieval_size - 1)
        image_path = data['image_path']

        # load image; 'image_path' is a list of views — only the first is used
        image_id = data['id']
        image = Image.open(os.path.join(self.image_root, image_path[0])).convert('RGB')
        image = self.vis_processor(image)

        # load and normalize the report text
        caption = data['report']
        caption = self.clean_reports(caption)

        # collect reference reports: this sample's own first, then the drawn ones
        all_ref_captions = [self.clean_reports(data['ref_report'])]
        for data_sample in data_samples:
            all_ref_captions.append(self.clean_reports(data_sample['ref_report']))

        # sample unlabeled reports
        unlabeled_caption = random.sample(self.unlabeled_data_list, self.retrieval_size)

        # BUG FIX: the original returned only the last cleaned ref report and
        # left `all_ref_captions` unused; return the full list instead.
        return {"image": image,
                "text_input": caption,
                "ref_caption": all_ref_captions,
                "unlabeled_caption": unlabeled_caption,
                "image_id": image_id}

    def clean_reports(self, report):
        """Normalize a radiology report (R2Gen-style cleaning).

        Collapses repeated underscores/spaces/dots, strips enumerated list
        numbering ("1. ", ". 2. ", ...), lowercases, removes punctuation from
        each sentence, and re-joins sentences with ' . ', ending in ' .'.
        """
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        # BUG FIX: the original filtered on `sent_cleaner(sent) != []`, which is
        # always True (a str never equals a list), so empty cleaned sentences
        # leaked into the joined output; filter on string truthiness instead.
        tokens = [cleaned for cleaned in (sent_cleaner(sent) for sent in report_cleaner(report)) if cleaned]
        report = ' . '.join(tokens) + ' .'
        return report