|
a |
|
b/code/dnc_code/test.py |
|
|
1 |
# Training DNC |
|
|
2 |
|
|
|
3 |
import sys |
|
|
4 |
import time |
|
|
5 |
import torch |
|
|
6 |
import random |
|
|
7 |
import numpy as np |
|
|
8 |
|
|
|
9 |
# torch.autograd.set_detect_anomaly(True) # Setting Anomaly Detection True for finding bad operations |
|
|
10 |
|
|
|
11 |
####### Following function is adapted from the implemention by loudinthecloud on Github ########## |
|
|
12 |
def random_seed(): |
|
|
13 |
seed = int(time.time()*10000000) |
|
|
14 |
random.seed(seed) |
|
|
15 |
np.random.seed(int(seed/10000000)) # NumPy seed Range is 2**32 - 1 max |
|
|
16 |
torch.manual_seed(seed) |
|
|
17 |
########################################################################################################## |
|
|
18 |
|
|
|
19 |
def main(): |
|
|
20 |
if len(sys.argv) > 4: |
|
|
21 |
if sys.argv[1] == "1": |
|
|
22 |
if sys.argv[2] == "GPU" and torch.cuda.is_available(): # Checking if GPU Request is given or not and availability of CUDA |
|
|
23 |
from tasks.babi_task_GPU import task_babi |
|
|
24 |
elif sys.argv[2] == "CPU": |
|
|
25 |
from tasks.babi_task import task_babi |
|
|
26 |
else: |
|
|
27 |
print("Please specify the run device (GPU/CPU)") |
|
|
28 |
exit() |
|
|
29 |
c_task = task_babi() # Initialization of the bAbI Task |
|
|
30 |
print("\nStarting bAbI Question Answering Task for DNC\n") |
|
|
31 |
elif sys.argv[1] == "2": |
|
|
32 |
if sys.argv[2] == "GPU" and torch.cuda.is_available(): # Checking if GPU Request is given or not and availability of CUDA |
|
|
33 |
from tasks.ner_task_bio_GPU import task_NER |
|
|
34 |
elif sys.argv[2] == "CPU": |
|
|
35 |
from tasks.ner_task_bio import task_NER |
|
|
36 |
else: |
|
|
37 |
print("Please specify the run device (GPU/CPU)") |
|
|
38 |
exit() |
|
|
39 |
c_task = task_NER() # Initialization of the NER Task (This is for BIO Tagging. Edit at line 33 and 35 for BIEOS tagging) |
|
|
40 |
print("\nStarting Medical NER Task for DNC\n") |
|
|
41 |
else: |
|
|
42 |
print("Unidentified task, please refer README file") |
|
|
43 |
exit() |
|
|
44 |
else: |
|
|
45 |
print("Incorrect Number of arguments") |
|
|
46 |
exit() |
|
|
47 |
|
|
|
48 |
epoch = sys.argv[3] # Last Epoch number till the model was trained (eg: 0) |
|
|
49 |
batch = sys.argv[4] # Last Batch Number till the model was trained (eg: 1000) |
|
|
50 |
batch_size = 1 |
|
|
51 |
|
|
|
52 |
# Random Seed |
|
|
53 |
random_seed() |
|
|
54 |
|
|
|
55 |
c_task.init_dnc() |
|
|
56 |
c_task.init_loss() |
|
|
57 |
c_task.batch_size = batch_size |
|
|
58 |
c_task.load_model(2, epoch, batch) |
|
|
59 |
results = c_task.test_model() |
|
|
60 |
|
|
|
61 |
print(results) |
|
|
62 |
|
|
|
63 |
if __name__ == '__main__': |
|
|
64 |
main() |