[349d16]: /code/dnc_code/tasks/ner_task_bio.py

# Named Entity Recognition on Medical Data (BIO Tagging)
# Bio-Word2Vec Embeddings Source and Reference: https://github.com/ncbi-nlp/BioWordVec
import os
import re
import torch
import pickle
from torch import nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import random
from DNC.dnc import DNC_Module # Importing DNC Implementation

class task_NER():

    def __init__(self):
        self.name = "NER_task_bio"

        # Controller Params
        self.controller_size = 128
        self.controller_layers = 1

        # Head Params
        self.num_read_heads = 1
        self.num_write_heads = 1

        # Processor Params
        self.num_inputs = 200  # Length of Embeddings
        self.num_outputs = 7   # Class size

        # Memory Params
        self.memory_N = 128
        self.memory_M = 128

        # Training Params
        self.num_batches = -1
        self.save_batch = 5   # Saving the model after every save_batch number of batches
        self.batch_size = 10
        self.num_epoch = 4

        # Optimizer Params
        self.adam_lr = 1e-4
        self.adam_betas = (0.9, 0.999)
        self.adam_eps = 1e-8

        # Handles
        self.machine = None
        self.loss = None
        self.optimizer = None

        # Class Dictionaries
        self.labelDict = None    # Label Dictionary - Labels to Index
        self.reverseDict = None  # Inverse Label Dictionary - Index to Labels

        # File Paths
        self.concept_path_train = "../medical_data/train_data/concept"  # Path to train concept files
        self.text_path_train = "../medical_data/train_data/txt"         # Path to train text summaries
        self.concept_path_test = "../medical_data/test_data/concept"    # Path to test concept files
        self.text_path_test = "../medical_data/test_data/txt"           # Path to test text summaries
        self.save_path = "../medical_data/cleaned_files"                # Save path for pre-processed data
        self.embed_dic_path = "../medical_data/embeddings/bio_embedding_dictionary.dat"  # Word2Vec embedding dictionary path
        self.random_vec = "../medical_data/embeddings/random_vec.dat"   # Path to random embedding (used to create new vectors)
        self.model_path = "../saved_models/"                            # Stores trained models

        # Miscellaneous
        self.padding_symbol = np.full((self.num_inputs), 0.01)  # Padding symbol embedding
    def get_task_name(self):
        return self.name

    def init_dnc(self):
        self.machine = DNC_Module(self.num_inputs, self.num_outputs, self.controller_size, self.controller_layers,
                                  self.num_read_heads, self.num_write_heads, self.memory_N, self.memory_M)

    def init_loss(self):
        self.loss = nn.CrossEntropyLoss(reduction='mean')  # Cross-entropy loss -> softmax activation + cross-entropy

    def init_optimizer(self):
        self.optimizer = optim.Adam(self.machine.parameters(), lr=self.adam_lr, betas=self.adam_betas, eps=self.adam_eps)
    def calc_loss(self, Y_pred, Y):
        # Y: dim -> (sequence_len x batch_size)
        # Y_pred: dim -> (sequence_len x batch_size x num_outputs)
        loss_vec = torch.empty(Y.shape[0], dtype=torch.float32)
        for i in range(Y_pred.shape[0]):
            loss_vec[i] = self.loss(Y_pred[i], Y[i])
        return torch.mean(loss_vec)
    def calc_cost(self, Y_pred, Y):  # Calculates % cost
        # Y: dim -> (sequence_len x batch_size)
        # Y_pred: dim -> (sequence_len x batch_size x sequence_width)
        '''
        Note:
        1). A prediction counts as a True Positive only if it matches the labelled entity completely
            (not partially); a partial match counts as a False Negative.
        2). A prediction counts as a False Positive if it is a full predicted entity (a 'B' tag followed
            by its 'I' tags, i.e. a B-I-I-I pattern) that does not completely match the labelled entity.
        '''
        # Stores correct class labels for each entity type
        class_bag = {}
        class_bag['problem'] = 0        # Total labels
        class_bag['test'] = 0           # Total labels
        class_bag['treatment'] = 0      # Total labels
        class_bag['problem_cor'] = 0    # Correctly classified labels
        class_bag['test_cor'] = 0       # Correctly classified labels
        class_bag['treatment_cor'] = 0  # Correctly classified labels
        class_bag['problem_fp'] = 0     # False positive classified labels
        class_bag['test_fp'] = 0        # False positive classified labels
        class_bag['treatment_fp'] = 0   # False positive classified labels

        pred_class = np.transpose(F.softmax(Y_pred, dim=2).max(2)[1].numpy()).reshape(-1)  # Predicted class. dim -> (sequence_len*batch_size)
        Y = np.transpose(Y.numpy()).reshape(-1)   # Converting to a NumPy array and linearizing
        cor_pred = (Y == pred_class).astype(int)  # Comparing predictions and labels to find correct predictions
        class_bag['word_pred_acc'] = np.divide(np.sum(cor_pred), cor_pred.size)*100.0  # % accuracy of correctly predicted words (not entities)

        # Getting the beginning index of all the entities
        beg_idx = list(np.where(np.in1d(Y, [0, 2, 4]))[0])

        # Getting the end index of all the entities (every index just before a 'Begin'/'Other' tag, provided that index itself is not tagged 'Other')
        target = np.where(np.in1d(Y, [0, 2, 4, 6]))[0] - 1
        if target[0] == -1:
            target = target[1:]
        end_idx = list(target[np.where(Y[target] != 6)[0]])
        if Y[-1] != 6:
            end_idx.append(Y.size-1)

        assert len(beg_idx) == len(end_idx)  # Sanity check
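        # Illustrative example (added for clarity, not taken from the data) of how entity spans are
        # recovered from the linearized label vector, using the index scheme from initialize_labels()
        # (0 = b-problem, 1 = i-problem, 2 = b-test, ..., 6 = o):
        #   Y       = [6, 0, 1, 1, 6, 2, 6]
        #   beg_idx = [1, 5]   -> positions of the 'b-*' tags
        #   end_idx = [3, 5]   -> last position of each entity (the index just before the next 'b-*'/'o' tag)
        # i.e. one 'problem' entity spanning words 1-3 and one 'test' entity at word 5.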
        class_bag['total'] = len(beg_idx)  # Total number of entities

        # Counting entities
        sum_vec = np.cumsum(cor_pred)  # Cumulative sum of the correctness vector
        for b, e in zip(beg_idx, end_idx):
            idx_range = e-b+1                    # Entity span
            sum_range = sum_vec[e]-sum_vec[b]+1  # Count of entity elements which are predicted correctly
            lab = self.reverseDict[Y[b]][2:]     # Extracting the entity type (problem, test or treatment)
            class_bag[lab] = class_bag[lab]+1    # Counting each entity type
            if sum_range == idx_range:           # +1 if the entity is classified correctly
                class_bag[lab+'_cor'] = class_bag[lab+'_cor']+1

        # Detecting False Positives
        # Getting the beginning index of all the entities in the predicted results
        beg_idx_p = list(np.where(np.in1d(pred_class, [0, 2, 4]))[0])
        for b in beg_idx_p:
            if cor_pred[b] == 0:
                lab = self.reverseDict[pred_class[b]][2:]
                class_bag[lab+'_fp'] = class_bag[lab+'_fp']+1
        return class_bag
    def print_word(self, token_class):  # Prints the class name for a class number
        word = self.reverseDict[token_class]
        print(word + "\n")

    def clip_grads(self):  # Clipping gradients for stability
        """Gradient clipping to the range [-10, 10]."""
        parameters = list(filter(lambda p: p.grad is not None, self.machine.parameters()))
        for p in parameters:
            p.grad.data.clamp_(-10, 10)
    def initialize_labels(self):  # Initializing the label dictionaries for Labels->Index and Index->Labels
        self.labelDict = {}    # Label Dictionary - Labels to Index
        self.reverseDict = {}  # Inverse Label Dictionary - Index to Labels

        # Using the BIO labelling scheme
        self.labelDict['b-problem'] = 0    # Problem - Beginning
        self.labelDict['i-problem'] = 1    # Problem - Inside
        self.labelDict['b-test'] = 2       # Test - Beginning
        self.labelDict['i-test'] = 3       # Test - Inside
        self.labelDict['b-treatment'] = 4  # Treatment - Beginning
        self.labelDict['i-treatment'] = 5  # Treatment - Inside
        self.labelDict['o'] = 6            # Outside token

        # Making the inverse label dictionary
        for k in self.labelDict.keys():
            self.reverseDict[self.labelDict[k]] = k

        # Saving the dictionaries into a file
        self.save_data([self.labelDict, self.reverseDict], os.path.join(self.save_path, "label_dicts_bio.dat"))
    def parse_concepts(self, file_path):  # Parses a concept file to extract concepts and labels
        conceptList = []  # Stores all the concepts in the file

        f = open(file_path)      # Opening and reading a concept file
        content = f.readlines()  # Reading all the lines in the concept file
        f.close()                # Closing the concept file

        for x in content:  # Reading each line in the concept file
            dic = {}

            # Cleaning and extracting the entities, labels and their positions in the corresponding medical summaries
            x = re.sub('\n', ' ', x)
            x = re.sub(r'\ +', ' ', x)
            x = x.strip().split('||')

            temp1, label = x[0].split(' '), x[1].split('=')[1][1:-1]
            temp1[0] = temp1[0][3:]
            temp1[-3] = temp1[-3][0:-1]
            entity = temp1[0:-2]

            if len(entity) >= 1:
                lab = ['i']*len(entity)
                lab[0] = 'b'
                lab = [l+"-"+label for l in lab]
            else:
                print("Data in file: " + file_path + ", not in expected format..")
                exit()

            noLab = [self.labelDict[l] for l in lab]
            sLine, sCol = int(temp1[-2].split(":")[0]), int(temp1[-2].split(":")[1])
            eLine, eCol = int(temp1[-1].split(":")[0]), int(temp1[-1].split(":")[1])

            '''
            # Printing the information
            print("------------------------------------------------------------")
            print("Entity: " + str(entity))
            print("Entity Label: " + label)
            print("Labels - BIO form: " + str(lab))
            print("Labels Index: " + str(noLab))
            print("Start Line: " + str(sLine) + ", Start Column: " + str(sCol))
            print("End Line: " + str(eLine) + ", End Column: " + str(eCol))
            print("------------------------------------------------------------")
            '''

            # Storing the information as a dictionary
            dic['entity'] = entity       # Entity name (as a list of words)
            dic['label'] = label         # Common label
            dic['BIO_labels'] = lab      # List of BIO labels for each word
            dic['label_index'] = noLab   # Labels in index form
            dic['start_line'] = sLine    # Start line of the concept in the corresponding text summary
            dic['start_word_no'] = sCol  # Starting word number of the concept in the start line
            dic['end_line'] = eLine      # End line of the concept in the corresponding text summary
            dic['end_word_no'] = eCol    # Ending word number of the concept in the end line

            # Appending the concept dictionary to the list
            conceptList.append(dic)

        return conceptList  # Returning all the concepts in the current file as a list of dictionaries
    def parse_summary(self, file_path):  # Parses the text summaries
        file_lines = []  # Stores the lines of the file as word lists
        tags = []        # Stores the corresponding label for each word in the file (default label: 'o' [Outside])
        default_label = len(self.labelDict)-1  # default_label is 6 (corresponding to the 'o' [Outside] token)
        # counter = 1  # Temporary variable used while printing

        f = open(file_path)      # Opening and reading a summary file
        content = f.readlines()  # Reading all the lines in the file
        f.close()

        for x in content:
            x = re.sub('\n', ' ', x)
            x = re.sub(r'\ +', ' ', x)
            file_lines.append(x.strip().split(" "))           # Splitting the line into a word list and appending it to the file list
            tags.append([default_label]*len(file_lines[-1]))  # Assigning the default_label to all the words in the line

            '''
            # Printing the information
            print("------------------------------------------------------------")
            print("File Lines No: " + str(counter))
            print(file_lines[-1])
            print("\nCorresponding labels:")
            print(tags[-1])
            print("------------------------------------------------------------")
            counter += 1
            '''

            assert len(tags[-1]) == len(file_lines[-1]), "Line length is not matching labels length..."  # Sanity check

        return file_lines, tags
    def modify_labels(self, conceptList, tags):  # Replaces the default label of each word in the text files with the true label from the concept files
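        # Worked example (illustrative, not taken from the data): a 4-word 'problem' concept with
        # label_index = [0, 1, 1, 1] that starts at line 2, word 3 and ends at line 3, word 1,
        # where line 2 has 5 words and line 3 has 4 words:
        #   i == start: tags[1][3:]  = label_index[0:2] = [0, 1]  (b-problem, i-problem), beg = 2
        #   i == end:   tags[2][0:2] = label_index[2:]  = [1, 1]  (i-problem, i-problem)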
        for e in conceptList:  # Iterating over all the dictionary elements in the concept list
            if e['start_line'] == e['end_line']:  # Checking whether the concept spans a single line or multiple lines in the summary
                tags[e['start_line']-1][e['start_word_no']:e['end_word_no']+1] = e['label_index'][:]
            else:
                start = e['start_line']
                end = e['end_line']
                beg = 0
                for i in range(start, end+1):  # Distributing the labels over multiple lines in the text summary
                    if i == start:
                        tags[i-1][e['start_word_no']:] = e['label_index'][0:len(tags[i-1])-e['start_word_no']]
                        beg = len(tags[i-1])-e['start_word_no']
                    elif i == end:
                        tags[i-1][0:e['end_word_no']+1] = e['label_index'][beg:]
                    else:
                        tags[i-1][:] = e['label_index'][beg:beg+len(tags[i-1])]
                        beg = beg+len(tags[i-1])
        return tags
    def print_data(self, file, file_lines, tags):  # Prints the given data
        counter = 1
        print("\n************ Printing details of the file: " + file + " ************\n")
        for x in file_lines:
            print("------------------------------------------------------------")
            print("File Lines No: " + str(counter))
            print(x)
            print("\nCorresponding labels:")
            print([self.reverseDict[i] for i in tags[counter-1]])
            print("\nCorresponding Label Indices:")
            print(tags[counter-1])
            print("------------------------------------------------------------")
            counter += 1

    def save_data(self, obj_list, s_path):  # Saves the objects into a binary file using Pickle
        # Note: 'obj_list' must be a list and nothing else
        pickle.dump(tuple(obj_list), open(s_path, 'wb'))
    def acquire_data(self, task):  # Reads all the concept files to get concepts and labels, processes them and saves them
        data = {}  # Dictionary storing the data objects (conceptList, file_lines, tags), each indexed by file name

        if task == 'train':  # Determining the task type to assign the data path accordingly
            t_path = self.text_path_train
            c_path = self.concept_path_train
        else:
            t_path = self.text_path_test
            c_path = self.concept_path_test

        for f in os.listdir(t_path):
            f1 = f.split('.')[0] + ".con"
            if os.path.isfile(os.path.join(c_path, f1)):
                conceptList = self.parse_concepts(os.path.join(c_path, f1))     # Parsing concepts and labels from the corresponding concept file
                file_lines, tags = self.parse_summary(os.path.join(t_path, f))  # Parsing the document summary to get the written notes
                tags = self.modify_labels(conceptList, tags)                    # Replacing the default label of each word with the true label from the concept file
                data[f1] = [conceptList, file_lines, tags]                      # Storing each object in the dictionary
                # self.print_data(f, file_lines, tags)  # Printing the details
        return data
    def structure_data(self, data_dict):  # Structures the data into a trainable form
        final_line_list = []  # Stores the words of each file in a separate sub-list
        final_tag_list = []   # Stores the tags of each file in a separate sub-list

        for k in data_dict.keys():  # Extracting data from each pre-processed file in the dictionary
            file_lines = data_dict[k][1]  # Extracting the story
            tags = data_dict[k][2]        # Extracting the corresponding labels

            # Creating empty lists
            temp1 = []
            temp2 = []

            # Merging all the lines of the file into a single list; same for the corresponding labels
            for i in range(len(file_lines)):
                temp1.extend(file_lines[i])
                temp2.extend(tags[i])
            assert len(temp1) == len(temp2), "Word length not matching label length for story in " + str(k)  # Sanity check

            final_line_list.append(temp1)
            final_tag_list.append(temp2)

        assert len(final_line_list) == len(final_tag_list), "Number of stories not matching number of label lists"  # Sanity check
        return final_line_list, final_tag_list
    def padding(self, line_list, tag_list):  # Pads stories with the padding symbol to make them all the same length
        diff = 0
        max_len = 0
        outside_class = len(self.labelDict)-1  # The padding symbol is labelled as the "outside" class

        # Calculating the maximum summary length
        for i in range(len(line_list)):
            if len(line_list[i]) > max_len:
                max_len = len(line_list[i])

        for i in range(len(line_list)):
            diff = max_len - len(line_list[i])
            line_list[i].extend([self.padding_symbol]*diff)
            tag_list[i].extend([outside_class]*diff)
            assert (len(line_list[i]) == max_len) and (len(line_list[i]) == len(tag_list[i])), "Padding unsuccessful"  # Sanity check
        return np.asarray(line_list), np.asarray(tag_list)  # NumPy arrays of size (batch_size x story_length x word_size) and (batch_size x story_length) respectively
    def embed_input(self, line_list):  # Converts words to vector embeddings
        final_list = []  # Stores embedded words
        summary = None   # Temp variable
        word = None      # Temp variable
        temp = None      # Temp variable

        embed_dic = pickle.load(open(self.embed_dic_path, 'rb'))  # Loading the word2vec dictionary using Pickle
        r_embed = pickle.load(open(self.random_vec, 'rb'))        # Loading the random embedding

        for i in range(len(line_list)):  # Iterating over all the summaries
            summary = line_list[i]
            final_list.append([])  # Reserving space for the current summary
            for j in range(len(summary)):
                word = summary[j].lower()
                if word in embed_dic:  # Checking whether the word exists in the dictionary
                    final_list[-1].append(embed_dic[word])
                else:
                    temp = r_embed[:]      # Copying the values of the list
                    random.shuffle(temp)   # Randomly shuffling the base embedding to make it unique
                    temp = np.asarray(temp, dtype=np.float32)  # Converting to a NumPy array
                    final_list[-1].append(temp)
        return final_list
    def prepare_data(self, task='train'):  # Preparing all the necessary data
        line_list, tag_list = None, None
        '''
        line_list is a list of rows, where each row is the list of all the words in one medical summary.
        tag_list has the same structure, except that it stores the label of each word.
        '''
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)  # Creating the save directory if it does not exist

        if not os.path.exists(os.path.join(self.save_path, "label_dicts_bio.dat")):
            self.initialize_labels()  # Initializing the label-to-index dictionaries
        else:
            self.labelDict, self.reverseDict = pickle.load(open(os.path.join(self.save_path, "label_dicts_bio.dat"), 'rb'))  # Loading the label dictionaries

        if not os.path.exists(os.path.join(self.save_path, "object_dict_bio_"+str(task)+".dat")):
            data_dict = self.acquire_data(task)                   # Reading data from the files
            line_list, tag_list = self.structure_data(data_dict)  # Structuring the data into proper form
            line_list = self.embed_input(line_list)               # Embedding the input words
            self.save_data([line_list, tag_list], os.path.join(self.save_path, "object_dict_bio_"+str(task)+".dat"))
        else:
            line_list, tag_list = pickle.load(open(os.path.join(self.save_path, "object_dict_bio_"+str(task)+".dat"), 'rb'))  # Loading previously saved data
        return line_list, tag_list
    def get_data(self, task='train'):
        line_list, tag_list = self.prepare_data(task)

        # Shuffling stories
        story_idx = list(range(0, len(line_list)))
        random.shuffle(story_idx)

        num_batch = int(len(story_idx)/self.batch_size)
        self.num_batches = num_batch

        # Output data
        x_out = []
        y_out = []

        counter = 1
        for i in story_idx:
            if num_batch <= 0:
                break
            x_out.append(line_list[i])
            y_out.append(tag_list[i])
            if counter % self.batch_size == 0:
                counter = 0
                # Padding the batch and converting inputs and label indices to tensors
                x_out_pad, y_out_pad = self.padding(x_out, y_out)
                x_out_array = torch.tensor(x_out_pad.swapaxes(0, 1), dtype=torch.float32)  # Converting from (batch_size x story_length x word_size) to (story_length x batch_size x word_size)
                y_out_array = torch.tensor(y_out_pad.swapaxes(0, 1), dtype=torch.long)     # Converting from (batch_size x story_length) to (story_length x batch_size)
                x_out = []
                y_out = []
                num_batch -= 1
                yield (self.num_batches - num_batch), x_out_array, y_out_array
            counter += 1
    def train_model(self):
        # Here, the model is optimized using cross-entropy loss.
        loss_list = []
        seq_length = []
        last_batch = 0

        # self.load_model(1, 99, 13)  # Loading a pre-trained model to train further

        for j in range(self.num_epoch):
            for batch_num, X, Y in self.get_data(task='train'):
                self.optimizer.zero_grad()                    # Zeroing old gradients before computing fresh ones
                self.machine.initialization(self.batch_size)  # Initializing states
                Y_out = torch.empty((X.shape[0], X.shape[1], self.num_outputs), dtype=torch.float32)  # dim: (seq_len x batch_size x num_output)

                # A backward pass over the whole sequence (backward_prediction) is run first; the forward
                # pass then feeds each input step together with the backward embedding of the same position
                # (indexed in reverse), so each prediction can use context from both directions.
                embeddings = self.machine.backward_prediction(X)  # Embeddings of the sequence read in reverse order

                temp_size = X.shape[0]
                for i in range(temp_size):
                    Y_out[i, :, :], _ = self.machine(X[i], embeddings[temp_size-i-1])  # Passing the backward embeddings in reverse order

                loss = self.calc_loss(Y_out, Y)
                loss.backward()
                self.clip_grads()
                self.optimizer.step()

                class_bag = self.calc_cost(Y_out, Y)
                corr = class_bag['problem_cor']+class_bag['test_cor']+class_bag['treatment_cor']
                tot = class_bag['total']

                loss_list += [loss.item()]
                seq_length += [Y.shape[0]]

                if (batch_num % self.save_batch) == 0:
                    self.save_model(j, batch_num)

                last_batch = batch_num
                print("Epoch: " + str(j) + "/" + str(self.num_epoch) + ", Batch: " + str(batch_num) + "/" + str(self.num_batches) + ", Loss: {0:.2f}, ".format(loss.item()) +
                      "Batch Accuracy (Entity Prediction): {0:.2f} %, ".format((float(corr)/float(tot))*100.0) + "Batch Accuracy (Word Prediction): {0:.2f} %".format(class_bag['word_pred_acc']))
            self.save_model(j, last_batch)
    def test_model(self):  # Testing the model
        correct = 0
        total = 0

        result_dict = {}
        result_dict['total_problem'] = 0             # Total labels in the data
        result_dict['total_test'] = 0                # Total labels in the data
        result_dict['total_treatment'] = 0           # Total labels in the data
        result_dict['correct_problem'] = 0           # Correctly classified labels
        result_dict['correct_test'] = 0              # Correctly classified labels
        result_dict['correct_treatment'] = 0         # Correctly classified labels
        result_dict['false_positive_problem'] = 0    # False positive labels
        result_dict['false_positive_test'] = 0       # False positive labels
        result_dict['false_positive_treatment'] = 0  # False positive labels

        print("\n")
        for batch_num, X, Y in self.get_data(task='test'):
            self.machine.initialization(self.batch_size)  # Initializing states
            Y_out = torch.empty((X.shape[0], X.shape[1], self.num_outputs), dtype=torch.float32)  # dim: (seq_len x batch_size x num_output)

            # Same bidirectional scheme as in train_model: a backward pass provides per-step embeddings
            # which are fed alongside each input step in reverse order.
            embeddings = self.machine.backward_prediction(X)

            temp_size = X.shape[0]
            for i in range(temp_size):
                Y_out[i, :, :], _ = self.machine(X[i], embeddings[temp_size-i-1])

            class_bag = self.calc_cost(Y_out, Y)
            corr = class_bag['problem_cor']+class_bag['test_cor']+class_bag['treatment_cor']
            tot = class_bag['total']

            result_dict['total_problem'] = result_dict['total_problem'] + class_bag['problem']
            result_dict['total_test'] = result_dict['total_test'] + class_bag['test']
            result_dict['total_treatment'] = result_dict['total_treatment'] + class_bag['treatment']
            result_dict['correct_problem'] = result_dict['correct_problem'] + class_bag['problem_cor']
            result_dict['correct_test'] = result_dict['correct_test'] + class_bag['test_cor']
            result_dict['correct_treatment'] = result_dict['correct_treatment'] + class_bag['treatment_cor']
            result_dict['false_positive_problem'] = result_dict['false_positive_problem'] + class_bag['problem_fp']
            result_dict['false_positive_test'] = result_dict['false_positive_test'] + class_bag['test_fp']
            result_dict['false_positive_treatment'] = result_dict['false_positive_treatment'] + class_bag['treatment_fp']

            correct += corr
            total += tot
            print("Test Example " + str(batch_num) + "/" + str(self.num_batches) + " processed, Batch Accuracy: {0:.2f} %, ".format((float(corr)/float(tot))*100.0) + "Batch Accuracy (Word Prediction): {0:.2f} %".format(class_bag['word_pred_acc']))

        result_dict['accuracy'] = (float(correct)/float(total))*100.0
        result_dict = self.calc_metrics(result_dict)
        print("\nOverall Entity Prediction Accuracy: {0:.2f} %".format(result_dict['accuracy']))
        return result_dict
    def calc_metrics(self, result_dict):  # Calculates precision, recall and F1 metrics
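        # Metric definitions used below (added as documentation). For each class c (problem / test /
        # treatment), with TP = correct_c, FP = false_positive_c and total_c = number of labelled
        # entities of that class (TP + FN):
        #   precision_c = TP / (TP + FP)
        #   recall_c    = TP / total_c
        #   F1_c        = 2 * precision_c * recall_c / (precision_c + recall_c)
        # The macro-average F1 is the mean of the three per-class F1 scores; the micro-average F1 is
        # computed from the pooled TP/FP/total counts across the three classes.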
        precision_p = float(result_dict['correct_problem'])/float(result_dict['correct_problem'] + result_dict['false_positive_problem'])         # Problem precision
        recall_p = float(result_dict['correct_problem'])/float(result_dict['total_problem'])                                                      # Problem recall
        precision_ts = float(result_dict['correct_test'])/float(result_dict['correct_test'] + result_dict['false_positive_test'])                # Test precision
        recall_ts = float(result_dict['correct_test'])/float(result_dict['total_test'])                                                          # Test recall
        precision_tr = float(result_dict['correct_treatment'])/float(result_dict['correct_treatment'] + result_dict['false_positive_treatment'])  # Treatment precision
        recall_tr = float(result_dict['correct_treatment'])/float(result_dict['total_treatment'])                                                # Treatment recall

        f_score_p = 2*precision_p*recall_p/(precision_p+recall_p)       # Problem F1 score
        f_score_ts = 2*precision_ts*recall_ts/(precision_ts+recall_ts)  # Test F1 score
        f_score_tr = 2*precision_tr*recall_tr/(precision_tr+recall_tr)  # Treatment F1 score

        result_dict['problem_precision'] = precision_p
        result_dict['problem_recall'] = recall_p
        result_dict['problem_f1'] = f_score_p
        result_dict['test_precision'] = precision_ts
        result_dict['test_recall'] = recall_ts
        result_dict['test_f1'] = f_score_ts
        result_dict['treatment_precision'] = precision_tr
        result_dict['treatment_recall'] = recall_tr
        result_dict['treatment_f1'] = f_score_tr
        result_dict['macro_average_f1'] = (f_score_p + f_score_ts + f_score_tr)/3.0  # Macro-average F1 score

        # Micro-average F1 score
        correct_sum = result_dict['correct_problem'] + result_dict['correct_test'] + result_dict['correct_treatment']
        fp_sum = result_dict['false_positive_problem'] + result_dict['false_positive_test'] + result_dict['false_positive_treatment']
        total_sum = result_dict['total_problem'] + result_dict['total_test'] + result_dict['total_treatment']
        precision_avg = float(correct_sum)/float(correct_sum + fp_sum)
        recall_avg = float(correct_sum)/float(total_sum)
        result_dict['micro_average_f1'] = 2*precision_avg*recall_avg/(precision_avg+recall_avg)
        return result_dict
    def save_model(self, curr_epoch, curr_batch):
        # The 'start_epoch' and 'start_batch' fields below are the epoch and batch numbers from which to resume training the next time the model is loaded.
        # Note: it is recommended to resume from 'start_epoch' only, rather than from 'start_epoch' + 'start_batch', because batches are formed randomly.
        if not os.path.exists(os.path.join(self.model_path, self.name)):
            os.mkdir(os.path.join(self.model_path, self.name))
        state_dic = {'task_name': self.name, 'start_epoch': curr_epoch + 1, 'start_batch': curr_batch + 1, 'state_dict': self.machine.state_dict(), 'optimizer_dic': self.optimizer.state_dict()}
        filename = self.model_path + self.name + "/" + self.name + "_" + str(curr_epoch) + "_" + str(curr_batch) + "_saved_model.pth.tar"
        torch.save(state_dic, filename)
    def load_model(self, option, epoch, batch):
        path = self.model_path + self.name + "/" + self.name + "_" + str(epoch) + "_" + str(batch) + "_saved_model.pth.tar"
        if option == 1:  # Loading for training
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_dic'])
        else:  # Loading for testing
            checkpoint = torch.load(path)
            self.machine.load_state_dict(checkpoint['state_dict'])
            self.machine.eval()
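

# Minimal driver sketch (not part of the original file): it only calls methods defined above, but the
# surrounding script, its placement and the absence of command-line handling are assumptions.
if __name__ == "__main__":
    task = task_NER()
    task.init_dnc()        # Build the DNC_Module with the parameters set in __init__
    task.init_loss()       # Cross-entropy loss
    task.init_optimizer()  # Adam optimizer
    task.train_model()     # Train; checkpoints are written via save_model() every save_batch batches
    results = task.test_model()  # Evaluate on the test split and collect precision/recall/F1
    print(results)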