# chexbert/src/datasets/unlabeled_dataset.py
1
import torch
2
import pandas as pd
3
import numpy as np
4
from transformers import BertTokenizer
5
import bert_tokenizer
6
from torch.utils.data import Dataset, DataLoader
7
8
class UnlabeledDataset(Dataset):
    """The dataset to contain report impressions without any labels."""

    def __init__(self, csv_path):
        """Initialize the dataset object.

        @param csv_path (string): path to the csv file containing the reports. It
                                  should have a column named "Report Impression"
        """
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        impressions = bert_tokenizer.get_impressions_from_csv(csv_path)
        # Encoding happens once, up front; __len__ and __getitem__ only read
        # this attribute. NOTE(review): presumably a sequence of token-id
        # lists, one per report -- confirm against bert_tokenizer.tokenize.
        self.encoded_imp = bert_tokenizer.tokenize(impressions, tokenizer)

    def __len__(self):
        """Compute the length of the dataset.

        @return (int): number of encoded impressions held by the dataset
        """
        return len(self.encoded_imp)

    def __getitem__(self, idx):
        """Functionality to index into the dataset.

        @param idx (int or tensor): index into the dataset; a tensor index is
                                    converted to plain Python via ``tolist()``

        @return (dictionary): Has keys 'imp' and 'len'. The value of 'imp' is
                              a LongTensor of an encoded impression and the
                              value of 'len' is an integer representing the
                              length of imp's value. No 'label' key is
                              present because this dataset is unlabeled.
        """
        # Convert tensor indices to plain Python values before indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        imp = self.encoded_imp[idx]
        imp = torch.LongTensor(imp)
        return {"imp": imp, "len": imp.shape[0]}