import torch
import pandas as pd
import numpy as np
from transformers import BertTokenizer
import bert_tokenizer
from torch.utils.data import Dataset, DataLoader
class UnlabeledDataset(Dataset):
    """Dataset of tokenized report impressions that carries no labels."""

    def __init__(self, csv_path):
        """ Build the dataset by tokenizing every impression found in a csv
        @param csv_path (string): path to the csv file containing the reports.
                                  It should have a column named
                                  "Report Impression"
        """
        # The tokenizer is loaded before the csv is read, so a tokenizer
        # loading failure surfaces before any file access happens.
        wordpiece = BertTokenizer.from_pretrained('bert-base-uncased')
        reports = bert_tokenizer.get_impressions_from_csv(csv_path)
        # One encoded entry per impression; indexed by position in __getitem__.
        self.encoded_imp = bert_tokenizer.tokenize(reports, wordpiece)

    def __len__(self):
        """Report how many encoded impressions the dataset holds
        @return (int): number of entries in self.encoded_imp
        """
        return len(self.encoded_imp)

    def __getitem__(self, idx):
        """ Fetch a single encoded impression by position
        @param idx (int or tensor): index into the dataset; a tensor index is
                                    first converted with .tolist()
        @return (dictionary): Has keys 'imp' and 'len'. The value of 'imp' is
                              a LongTensor built from the encoded impression
                              and the value of 'len' is an integer giving the
                              size of that tensor along its first dimension.
                              No label is returned by this dataset.
        """
        # Samplers may hand over the index as a tensor; unwrap it so it can be
        # used to index self.encoded_imp.
        position = idx.tolist() if torch.is_tensor(idx) else idx
        tokens = torch.LongTensor(self.encoded_imp[position])
        return {"imp": tokens, "len": tokens.size(0)}