|
a |
|
b/code/dnc_code/train.py |
|
|
1 |
# Training DNC |
|
|
2 |
|
|
|
3 |
import sys |
|
|
4 |
import time |
|
|
5 |
import torch |
|
|
6 |
import random |
|
|
7 |
import numpy as np |
|
|
8 |
|
|
|
9 |
# torch.autograd.set_detect_anomaly(True) # Setting Anomaly Detection True for finding bad operations |
|
|
10 |
|
|
|
11 |
####### Following function is adapted from the NTM implemention by loudinthecloud on Github ########## |
|
|
12 |
def random_seed(): |
|
|
13 |
seed = int(time.time()*10000000) |
|
|
14 |
random.seed(seed) |
|
|
15 |
np.random.seed(int(seed/10000000)) # NumPy seed Range is 2**32 - 1 max |
|
|
16 |
torch.manual_seed(seed) |
|
|
17 |
########################################################################################################## |
|
|
18 |
|
|
|
19 |
def main(): |
|
|
20 |
if len(sys.argv) > 2: |
|
|
21 |
if sys.argv[1] == "1": |
|
|
22 |
if sys.argv[2] == "GPU" and torch.cuda.is_available(): # Checking if GPU Request is given or not and availability of CUDA |
|
|
23 |
from tasks.babi_task_GPU import task_babi |
|
|
24 |
elif sys.argv[2] == "CPU": |
|
|
25 |
from tasks.babi_task import task_babi |
|
|
26 |
else: |
|
|
27 |
print("Please specify the run device (GPU/CPU)") |
|
|
28 |
exit() |
|
|
29 |
c_task = task_babi() # Initialization of the bAbI Task |
|
|
30 |
print("\nStarting bAbI Question Answering Task for DNC\n") |
|
|
31 |
elif sys.argv[1] == "2": |
|
|
32 |
if sys.argv[2] == "GPU" and torch.cuda.is_available(): # Checking if GPU Request is given or not and availability of CUDA |
|
|
33 |
from tasks.ner_task_bio_GPU import task_NER |
|
|
34 |
elif sys.argv[2] == "CPU": |
|
|
35 |
from tasks.ner_task_bio import task_NER |
|
|
36 |
else: |
|
|
37 |
print("Please specify the run device (GPU/CPU)") |
|
|
38 |
exit() |
|
|
39 |
c_task = task_NER() # Initialization of the NER Task (This is for BIO Tagging. Edit at line 33 and 35 for BIEOS tagging) |
|
|
40 |
print("\nStarting Medical NER Task for DNC\n") |
|
|
41 |
else: |
|
|
42 |
print("Unidentified task, please refer README file") |
|
|
43 |
exit() |
|
|
44 |
else: |
|
|
45 |
print("Incorrect Number of arguments") |
|
|
46 |
exit() |
|
|
47 |
|
|
|
48 |
# Random Seed |
|
|
49 |
random_seed() |
|
|
50 |
|
|
|
51 |
c_task.init_dnc() |
|
|
52 |
c_task.init_loss() |
|
|
53 |
c_task.init_optimizer() |
|
|
54 |
|
|
|
55 |
c_task.train_model() |
|
|
56 |
print("Training Completed!") |
|
|
57 |
|
|
|
58 |
if __name__ == '__main__': |
|
|
59 |
main() |