--- /dev/null
+++ b/HINT/protocol_encode.py
@@ -0,0 +1,149 @@
+'''
+input:
+    data/raw_data.csv
+
+output:
+    data/sentence2embedding.pkl  (preprocessing: criteria sentence -> BioBERT embedding)
+    Protocol_Embedding           (module that encodes protocols at training time)
+'''
+
+import csv, pickle
+from tqdm import tqdm
+import torch
+torch.manual_seed(0)
+from torch import nn
+import torch.nn.functional as F
+
+def clean_protocol(protocol):
+    ## lowercase, split into lines, drop empty lines, strip whitespace
+    protocol = protocol.lower()
+    protocol_split = protocol.split('\n')
+    filter_out_empty_fn = lambda x: len(x.strip()) > 0
+    strip_fn = lambda x: x.strip()
+    protocol_split = list(filter(filter_out_empty_fn, protocol_split))
+    protocol_split = list(map(strip_fn, protocol_split))
+    return protocol_split
+
+def get_all_protocols():
+    input_file = 'data/raw_data.csv'
+    with open(input_file, 'r') as csvfile:
+        rows = list(csv.reader(csvfile, delimiter=','))[1:]
+    protocols = [row[9] for row in rows]  ## column 9 holds the eligibility criteria text
+    return protocols
+
+def split_protocol(protocol):
+    protocol_split = clean_protocol(protocol)
+    inclusion_idx, exclusion_idx = len(protocol_split), len(protocol_split)
+    for idx, sentence in enumerate(protocol_split):
+        if "inclusion" in sentence:
+            inclusion_idx = idx
+            break
+    for idx, sentence in enumerate(protocol_split):
+        if "exclusion" in sentence:
+            exclusion_idx = idx
+            break
+    if inclusion_idx < exclusion_idx < len(protocol_split) - 1:
+        inclusion_criteria = protocol_split[inclusion_idx:exclusion_idx]
+        exclusion_criteria = protocol_split[exclusion_idx:]
+        assert len(inclusion_criteria) > 0 and len(exclusion_criteria) > 0, \
+            (len(inclusion_criteria), len(exclusion_criteria), len(protocol_split))
+        return inclusion_criteria, exclusion_criteria  ## list, list
+    else:
+        ## no clean inclusion/exclusion split found; return the whole protocol as a 1-tuple
+        return protocol_split,
+
+def collect_cleaned_sentence_set():
+    protocol_lst = get_all_protocols()
+    cleaned_sentence_lst = []
+    for protocol in protocol_lst:
+        result = split_protocol(protocol)
+        cleaned_sentence_lst.extend(result[0])
+        if len(result) == 2:
+            cleaned_sentence_lst.extend(result[1])
+    return set(cleaned_sentence_lst)
+
+
+def save_sentence_bert_dict_pkl():
+    cleaned_sentence_set = collect_cleaned_sentence_set()
+    from biobert_embedding.embedding import BiobertEmbedding
+    biobert = BiobertEmbedding()
+    def text2vec(text):
+        return biobert.sentence_vector(text)
+    protocol_sentence_2_embedding = dict()
+    for sentence in tqdm(cleaned_sentence_set):
+        protocol_sentence_2_embedding[sentence] = text2vec(sentence)
+    with open('data/sentence2embedding.pkl', 'wb') as fout:
+        pickle.dump(protocol_sentence_2_embedding, fout)
+    return
+
+def load_sentence_2_vec():
+    with open('data/sentence2embedding.pkl', 'rb') as fin:
+        sentence_2_vec = pickle.load(fin)
+    return sentence_2_vec
+
+def protocol2feature(protocol, sentence_2_vec):
+    ## if no split was found, result is a 1-tuple and both criteria share the full protocol
+    result = split_protocol(protocol)
+    inclusion_criteria, exclusion_criteria = result[0], result[-1]
+    inclusion_feature = [sentence_2_vec[sentence].view(1, -1) for sentence in inclusion_criteria if sentence in sentence_2_vec]
+    exclusion_feature = [sentence_2_vec[sentence].view(1, -1) for sentence in exclusion_criteria if sentence in sentence_2_vec]
+    if len(inclusion_feature) == 0:
+        inclusion_feature = torch.zeros(1, 768)
+    else:
+        inclusion_feature = torch.cat(inclusion_feature, 0)
+    if len(exclusion_feature) == 0:
+        exclusion_feature = torch.zeros(1, 768)
+    else:
+        exclusion_feature = torch.cat(exclusion_feature, 0)
+    return inclusion_feature, exclusion_feature
+
+
+class Protocol_Embedding(nn.Module):
+    def __init__(self, output_dim, highway_num, device):
+        super(Protocol_Embedding, self).__init__()
+        self.input_dim = 768  ## BioBERT sentence vectors are 768-dim
+        self.output_dim = output_dim
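+        ## inclusion and exclusion embeddings are concatenated before the linear
+        ## layer below, hence an input width of input_dim * 2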
+        self.highway_num = highway_num  ## stored for interface compatibility; no highway layers are used
+        self.fc = nn.Linear(self.input_dim * 2, output_dim)
+        self.f = F.relu
+        self.device = device
+        self.to(device)
+
+    def forward_single(self, inclusion_feature, exclusion_feature):
+        ## inclusion_feature, exclusion_feature: (num_sentences, 768)
+        inclusion_feature = inclusion_feature.to(self.device)
+        exclusion_feature = exclusion_feature.to(self.device)
+        ## mean-pool the sentence embeddings of each criteria block
+        inclusion_vec = torch.mean(inclusion_feature, 0).view(1, -1)
+        exclusion_vec = torch.mean(exclusion_feature, 0).view(1, -1)
+        return inclusion_vec, exclusion_vec
+
+    def forward(self, in_ex_feature):
+        result = [self.forward_single(in_mat, ex_mat) for in_mat, ex_mat in in_ex_feature]
+        inclusion_mat = torch.cat([in_vec for in_vec, ex_vec in result], 0)  ## (batch_size, 768)
+        exclusion_mat = torch.cat([ex_vec for in_vec, ex_vec in result], 0)  ## (batch_size, 768)
+        protocol_mat = torch.cat([inclusion_mat, exclusion_mat], 1)  ## (batch_size, 1536)
+        output = self.f(self.fc(protocol_mat))
+        return output
+
+    @property
+    def embedding_size(self):
+        return self.output_dim
+
+
+if __name__ == "__main__":
+    save_sentence_bert_dict_pkl()
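
Usage sketch (illustrative, not part of the patch): assuming `data/raw_data.csv` is in place and `save_sentence_bert_dict_pkl()` has been run once to build the embedding cache, the pieces above compose as follows. The protocol text and the values `output_dim=50` and `highway_num=3` are made-up examples.

    from HINT.protocol_encode import load_sentence_2_vec, protocol2feature, Protocol_Embedding

    sentence_2_vec = load_sentence_2_vec()  # dict: criteria sentence -> 768-dim BioBERT tensor

    # hypothetical raw protocol text (column 9 of data/raw_data.csv in practice)
    protocols = ["inclusion criteria:\nage > 18 years\nexclusion criteria:\npregnant or nursing"]

    # each protocol -> (inclusion_matrix, exclusion_matrix), each of shape (num_sentences, 768);
    # sentences missing from the cache fall back to a single zero row
    in_ex_feature = [protocol2feature(p, sentence_2_vec) for p in protocols]

    model = Protocol_Embedding(output_dim=50, highway_num=3, device='cpu')
    protocol_emb = model(in_ex_feature)  # shape: (len(protocols), 50)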