--- a +++ b/chexbert/src/datasets/unlabeled_dataset.py @@ -0,0 +1,40 @@ +import torch +import pandas as pd +import numpy as np +from transformers import BertTokenizer +import bert_tokenizer +from torch.utils.data import Dataset, DataLoader + +class UnlabeledDataset(Dataset): + """The dataset to contain report impressions without any labels.""" + + def __init__(self, csv_path): + """ Initialize the dataset object + @param csv_path (string): path to the csv file containing rhe reports. It + should have a column named "Report Impression" + """ + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + impressions = bert_tokenizer.get_impressions_from_csv(csv_path) + self.encoded_imp = bert_tokenizer.tokenize(impressions, tokenizer) + + def __len__(self): + """Compute the length of the dataset + + @return (int): size of the dataframe + """ + return len(self.encoded_imp) + + def __getitem__(self, idx): + """ Functionality to index into the dataset + @param idx (int): Integer index into the dataset + + @return (dictionary): Has keys 'imp', 'label' and 'len'. The value of 'imp' is + a LongTensor of an encoded impression. The value of 'label' + is a LongTensor containing the labels and 'the value of + 'len' is an integer representing the length of imp's value + """ + if torch.is_tensor(idx): + idx = idx.tolist() + imp = self.encoded_imp[idx] + imp = torch.LongTensor(imp) + return {"imp": imp, "len": imp.shape[0]}