# PVQA/dataset.py
import os
import pickle
from collections import defaultdict

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageTextContrastiveCollator:
    """Collate a batch of sample dicts into one dict of per-field lists.

    Intended for use as a ``collate_fn``: every sample must provide the keys
    ``'image'``, ``'question'`` and ``'answer'``; any other keys are ignored.
    """

    # Fields copied out of every sample, in the order they are collected.
    _FIELDS = ('image', 'question', 'answer')

    def __init__(self):
        """Create the collator; it holds no configuration or state."""

    def __call__(self, batch):
        """Group the samples of ``batch`` field by field.

        Args:
            batch: Iterable of mappings, each containing the keys ``'image'``,
                ``'question'`` and ``'answer'``.

        Returns:
            A ``defaultdict(list)`` mapping each field name to the list of
            that field's values, in batch order.  An empty batch yields an
            empty mapping.

        Raises:
            KeyError: If a sample lacks one of the required keys.
        """
        inputs = defaultdict(list)
        for data in batch:
            for field in self._FIELDS:
                inputs[field].append(data[field])

        # The 'image' values are left as a plain list rather than stacked into
        # a tensor: PVQAdataset in this file yields image paths, not tensors.
        return inputs


# Default location of the pickled PathVQA test-split records loaded by PVQAdataset.
pkl_path = '../PathVQA/pvqa/qas/test_vqa.pkl'


class PVQAdataset(Dataset):
    """Map-style dataset over pickled PathVQA question/answer records.

    Each record is expected to provide the keys ``'sent'`` (the question
    text), ``'label'`` (a mapping whose first key is used as the answer) and
    ``'img_id'`` (the image file name without its ``.jpg`` extension).

    Samples carry the image *path* only: no image is opened here and
    ``self.transform`` is built but not applied by ``__getitem__``.
    """

    def __init__(self, data_path=None, image_dir='../PathVQA/pvqa/images', split='test'):
        """Load the records and build the image transform.

        Args:
            data_path: Path of the pickle file holding the records.  ``None``
                (the default) uses the module-level ``pkl_path``.
            image_dir: Directory containing one sub-directory per split.
            split: Name of the sub-directory of ``image_dir`` holding the
                images that belong to the loaded records.
        """
        super().__init__()
        self.image_dir = image_dir
        self.split = split

        # NOTE: pickle.load can execute arbitrary code while unpickling; only
        # point this at files that come from a trusted source.
        with open(pkl_path if data_path is None else data_path, 'rb') as f:
            self.data = pickle.load(f)

        # Normalize(mean, std) with three per-channel values each.
        # NOTE(review): these look like the CLIP image statistics - confirm.
        normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Returns:
            A dict with ``'image'`` (path of the ``.jpg`` file, as a string),
            ``'question'`` (the record's ``'sent'`` value) and ``'answer'``
            (the first key of the record's ``'label'`` mapping).

        Raises:
            ValueError: If the record's ``'label'`` mapping is empty.
        """
        record = self.data[index]
        label = record['label']
        if not label:
            # An IndexError raised from __getitem__ is read as end-of-sequence
            # by plain ``for`` iteration and would silently truncate the
            # dataset, so an unanswerable record is reported explicitly.
            raise ValueError(f"record {index} has an empty 'label' mapping")

        img_path = os.path.join(self.image_dir, self.split, record['img_id']) + ".jpg"
        return {
            "image": img_path,
            "question": record['sent'],
            "answer": next(iter(label)),
        }


if __name__ == '__main__':
    # Smoke test: report the dataset size and show its first two samples
    # (fewer when the dataset holds fewer than two records).
    test = PVQAdataset()
    print(len(test))
    for sample_index in range(min(2, len(test))):
        print(test[sample_index])
70