Diff of /code/dnc_code/test.py [000000] .. [349d16]

Switch to unified view

a b/code/dnc_code/test.py
1
# Training DNC
2
3
import sys
4
import time
5
import torch
6
import random
7
import numpy as np
8
9
# torch.autograd.set_detect_anomaly(True) # Setting Anomaly Detection True for finding bad operations
10
11
####### Following function is adapted from the implemention by loudinthecloud on Github ##########
12
def random_seed():
13
    seed = int(time.time()*10000000)
14
    random.seed(seed)
15
    np.random.seed(int(seed/10000000))      # NumPy seed Range is 2**32 - 1 max
16
    torch.manual_seed(seed)
17
##########################################################################################################
18
19
def main():
20
    if len(sys.argv) > 4:
21
        if sys.argv[1] == "1":
22
            if sys.argv[2] == "GPU" and torch.cuda.is_available():  # Checking if GPU Request is given or not and availability of CUDA
23
                from tasks.babi_task_GPU import task_babi
24
            elif sys.argv[2] == "CPU":
25
                from tasks.babi_task import task_babi
26
            else:
27
                print("Please specify the run device (GPU/CPU)")
28
                exit()
29
            c_task = task_babi()                    # Initialization of the bAbI Task
30
            print("\nStarting bAbI Question Answering Task for DNC\n")
31
        elif sys.argv[1] == "2":
32
            if sys.argv[2] == "GPU" and torch.cuda.is_available():  # Checking if GPU Request is given or not and availability of CUDA
33
                from tasks.ner_task_bio_GPU import task_NER
34
            elif sys.argv[2] == "CPU":
35
                from tasks.ner_task_bio import task_NER
36
            else:
37
                print("Please specify the run device (GPU/CPU)")
38
                exit()
39
            c_task = task_NER()                    # Initialization of the NER Task (This is for BIO Tagging. Edit at line 33 and 35 for BIEOS tagging)
40
            print("\nStarting Medical NER Task for DNC\n")
41
        else:
42
            print("Unidentified task, please refer README file")
43
            exit()
44
    else:
45
        print("Incorrect Number of arguments")
46
        exit()
47
48
    epoch = sys.argv[3] # Last Epoch number till the model was trained (eg: 0)
49
    batch = sys.argv[4] # Last Batch Number till the model was trained (eg: 1000)
50
    batch_size = 1
51
52
    # Random Seed
53
    random_seed()
54
55
    c_task.init_dnc()
56
    c_task.init_loss()
57
    c_task.batch_size = batch_size
58
    c_task.load_model(2, epoch, batch)
59
    results = c_task.test_model()
60
61
    print(results)
62
63
if __name__ == '__main__':
64
    main()