Diff of /code/dnc_code/train.py [000000] .. [349d16]

Switch to unified view

a b/code/dnc_code/train.py
1
# Training DNC
2
3
import sys
4
import time
5
import torch
6
import random
7
import numpy as np
8
9
# torch.autograd.set_detect_anomaly(True) # Setting Anomaly Detection True for finding bad operations
10
11
####### Following function is adapted from the NTM implemention by loudinthecloud on Github ##########
12
def random_seed():
13
    seed = int(time.time()*10000000)
14
    random.seed(seed)
15
    np.random.seed(int(seed/10000000))      # NumPy seed Range is 2**32 - 1 max
16
    torch.manual_seed(seed)
17
##########################################################################################################
18
19
def main():
20
    if len(sys.argv) > 2:
21
        if sys.argv[1] == "1":
22
            if sys.argv[2] == "GPU" and torch.cuda.is_available():  # Checking if GPU Request is given or not and availability of CUDA
23
                from tasks.babi_task_GPU import task_babi
24
            elif sys.argv[2] == "CPU":
25
                from tasks.babi_task import task_babi
26
            else:
27
                print("Please specify the run device (GPU/CPU)")
28
                exit()
29
            c_task = task_babi()                    # Initialization of the bAbI Task
30
            print("\nStarting bAbI Question Answering Task for DNC\n")
31
        elif sys.argv[1] == "2":
32
            if sys.argv[2] == "GPU" and torch.cuda.is_available():  # Checking if GPU Request is given or not and availability of CUDA
33
                from tasks.ner_task_bio_GPU import task_NER
34
            elif sys.argv[2] == "CPU":
35
                from tasks.ner_task_bio import task_NER
36
            else:
37
                print("Please specify the run device (GPU/CPU)")
38
                exit()
39
            c_task = task_NER()                    # Initialization of the NER Task (This is for BIO Tagging. Edit at line 33 and 35 for BIEOS tagging)
40
            print("\nStarting Medical NER Task for DNC\n")
41
        else:
42
            print("Unidentified task, please refer README file")
43
            exit()
44
    else:
45
        print("Incorrect Number of arguments")
46
        exit()
47
48
    # Random Seed
49
    random_seed()
50
51
    c_task.init_dnc()
52
    c_task.init_loss()
53
    c_task.init_optimizer()
54
55
    c_task.train_model()
56
    print("Training Completed!")
57
58
if __name__ == '__main__':
59
    main()