# benchmark/protocol_encode.py

'''
input:
    data/raw_data.csv
output:
    data/sentence2embedding.pkl  (preprocessing cache: sentence -> BioBERT vector)
    Protocol_Embedding           (model that encodes protocols from these vectors)
'''
import csv, pickle
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

def clean_protocol(protocol):
    """Lowercase a raw protocol string and split it into stripped, non-empty lines."""
    protocol = protocol.lower()
    protocol_split = protocol.split('\n')
    filter_out_empty_fn = lambda x: len(x.strip()) > 0
    strip_fn = lambda x: x.strip()
    protocol_split = list(filter(filter_out_empty_fn, protocol_split))
    protocol_split = list(map(strip_fn, protocol_split))
    return protocol_split
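
# A minimal usage sketch of clean_protocol; the snippet below is made-up
# illustrative text, not a row from the real dataset.
def _demo_clean_protocol():
    raw = "Inclusion Criteria:\n\n  - age > 18\n  - signed consent\n"
    print(clean_protocol(raw))
    # -> ['inclusion criteria:', '- age > 18', '- signed consent']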

def get_all_protocols():
    """Read every protocol (eligibility criteria text) from the raw CSV."""
    input_file = 'data/raw_data.csv'
    with open(input_file, 'r') as csvfile:
        rows = list(csv.reader(csvfile, delimiter=','))[1:]  # drop the header row
    protocols = [row[9] for row in rows]  # column 9 holds the criteria text
    return protocols

def split_protocol(protocol):
    """Split a protocol into (inclusion_criteria, exclusion_criteria) when both
    markers are found; otherwise return the whole cleaned protocol as a 1-tuple."""
    protocol_split = clean_protocol(protocol)
    inclusion_idx, exclusion_idx = len(protocol_split), len(protocol_split)
    for idx, sentence in enumerate(protocol_split):
        if "inclusion" in sentence:
            inclusion_idx = idx
            break
    for idx, sentence in enumerate(protocol_split):
        if "exclusion" in sentence:
            exclusion_idx = idx
            break
    if inclusion_idx < exclusion_idx < len(protocol_split) - 1:
        inclusion_criteria = protocol_split[inclusion_idx:exclusion_idx]
        exclusion_criteria = protocol_split[exclusion_idx:]
        # both slices are non-empty by construction of the condition above
        assert len(inclusion_criteria) > 0 and len(exclusion_criteria) > 0
        return inclusion_criteria, exclusion_criteria  # list, list
    else:
        # no clear inclusion/exclusion split was found
        return (protocol_split,)
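
# Usage sketch for split_protocol on a made-up protocol containing both
# markers; with only one marker it would fall through to the 1-tuple branch.
def _demo_split_protocol():
    raw = ("Inclusion Criteria:\n- age > 18\n"
           "Exclusion Criteria:\n- prior chemotherapy\n")
    inclusion, exclusion = split_protocol(raw)
    print(inclusion)  # ['inclusion criteria:', '- age > 18']
    print(exclusion)  # ['exclusion criteria:', '- prior chemotherapy']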

def collect_cleaned_sentence_set():
    """Collect the set of unique cleaned criteria sentences across all protocols."""
    protocol_lst = get_all_protocols()
    cleaned_sentence_lst = []
    for protocol in protocol_lst:
        result = split_protocol(protocol)
        cleaned_sentence_lst.extend(result[0])
        if len(result) == 2:
            cleaned_sentence_lst.extend(result[1])
    return set(cleaned_sentence_lst)

def save_sentence_bert_dict_pkl():
    """Embed every unique sentence with BioBERT and cache the mapping as a pickle."""
    cleaned_sentence_set = collect_cleaned_sentence_set()
    from biobert_embedding.embedding import BiobertEmbedding
    biobert = BiobertEmbedding()

    def text2vec(text):
        return biobert.sentence_vector(text)

    protocol_sentence_2_embedding = dict()
    for sentence in tqdm(cleaned_sentence_set):
        protocol_sentence_2_embedding[sentence] = text2vec(sentence)
    with open('data/sentence2embedding.pkl', 'wb') as fout:
        pickle.dump(protocol_sentence_2_embedding, fout)

def load_sentence_2_vec():
    """Load the cached sentence -> embedding dictionary produced above."""
    with open('data/sentence2embedding.pkl', 'rb') as fin:
        sentence_2_vec = pickle.load(fin)
    return sentence_2_vec
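
# Hedged round-trip sketch: assumes save_sentence_bert_dict_pkl() has already
# been run so that data/sentence2embedding.pkl exists on disk.
def _demo_load_sentence_2_vec():
    sentence_2_vec = load_sentence_2_vec()
    some_sentence = next(iter(sentence_2_vec))
    print(some_sentence, sentence_2_vec[some_sentence].shape)  # e.g. torch.Size([768])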

def protocol2feature(protocol, sentence_2_vec):
    """Turn one protocol into two feature matrices (inclusion, exclusion), each of
    shape (num_sentences, 768); a zero row stands in when no sentence matches."""
    result = split_protocol(protocol)
    inclusion_criteria, exclusion_criteria = result[0], result[-1]
    inclusion_feature = [sentence_2_vec[sentence].view(1, -1)
                         for sentence in inclusion_criteria if sentence in sentence_2_vec]
    exclusion_feature = [sentence_2_vec[sentence].view(1, -1)
                         for sentence in exclusion_criteria if sentence in sentence_2_vec]
    if inclusion_feature == []:
        inclusion_feature = torch.zeros(1, 768)
    else:
        inclusion_feature = torch.cat(inclusion_feature, 0)
    if exclusion_feature == []:
        exclusion_feature = torch.zeros(1, 768)
    else:
        exclusion_feature = torch.cat(exclusion_feature, 0)
    return inclusion_feature, exclusion_feature
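
# Sketch of protocol2feature using a stub embedding table in place of the real
# BioBERT cache; the random 768-dim vectors are stand-ins for real embeddings.
def _demo_protocol2feature():
    raw = "Inclusion Criteria:\n- age > 18\nExclusion Criteria:\n- prior chemotherapy"
    stub_sentence_2_vec = {s: torch.randn(768) for s in clean_protocol(raw)}
    inc, exc = protocol2feature(raw, stub_sentence_2_vec)
    print(inc.shape, exc.shape)  # torch.Size([2, 768]) torch.Size([2, 768])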

class Protocol_Embedding(nn.Sequential):
    def __init__(self, output_dim, highway_num, device):
        super(Protocol_Embedding, self).__init__()
        self.input_dim = 768
        self.output_dim = output_dim
        self.highway_num = highway_num
        self.fc = nn.Linear(self.input_dim * 2, output_dim)
        self.f = F.relu
        self.device = device
        self.to(device)

    def forward_single(self, inclusion_feature, exclusion_feature):
        ## inclusion_feature, exclusion_feature: (num_sentences, 768)
        inclusion_feature = inclusion_feature.to(self.device)
        exclusion_feature = exclusion_feature.to(self.device)
        inclusion_vec = torch.mean(inclusion_feature, 0)   # average over sentences
        inclusion_vec = inclusion_vec.view(1, -1)
        exclusion_vec = torch.mean(exclusion_feature, 0)
        exclusion_vec = exclusion_vec.view(1, -1)
        return inclusion_vec, exclusion_vec

    def forward(self, in_ex_feature):
        result = [self.forward_single(in_mat, ex_mat) for in_mat, ex_mat in in_ex_feature]
        inclusion_mat = torch.cat([in_vec for in_vec, ex_vec in result], 0)  # (batch, 768)
        exclusion_mat = torch.cat([ex_vec for in_vec, ex_vec in result], 0)  # (batch, 768)
        protocol_mat = torch.cat([inclusion_mat, exclusion_mat], 1)          # (batch, 1536)
        output = self.f(self.fc(protocol_mat))                               # (batch, output_dim)
        return output

    @property
    def embedding_size(self):
        return self.output_dim
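
# End-to-end sketch of Protocol_Embedding on a random batch; output_dim and
# highway_num below are arbitrary placeholder values, not tuned settings.
def _demo_protocol_embedding():
    model = Protocol_Embedding(output_dim=50, highway_num=3, device=torch.device('cpu'))
    batch = [(torch.randn(4, 768), torch.randn(2, 768)) for _ in range(8)]
    print(model(batch).shape)  # torch.Size([8, 50])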

if __name__ == "__main__":
    # Preprocessing step: build and cache the sentence -> BioBERT embedding table.
    save_sentence_bert_dict_pkl()