"""Train_Script.py — distributed (DDP) training script for the CHD GNN segmentation model."""
1
from os import environ

from numpy import array, isin, argwhere
from pandas import read_csv
from torch import FloatTensor, max, save, load, no_grad
from torch import __version__ as torch_version
from torch.cuda import is_available, set_device, empty_cache, device_count
from torch.distributed import init_process_group, barrier
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.optim.zero_redundancy_optimizer import ZeroRedundancyOptimizer
from torch.nn import CrossEntropyLoss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard.writer import SummaryWriter
from torch_geometric import __version__ as pyg_version
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from Network import CHD_GNN
from Utilities import CHD_Dataset, __Load_Adjacency__
from Metrics import Accuracy_Util
@record
def main():
    """Distributed (DDP) training loop for the CHD_GNN model.

    Launched under torchrun: each process drives one GPU, identified by
    LOCAL_RANK. Rank 0 additionally prints progress, logs to TensorBoard
    and writes a checkpoint every epoch; all ranks then reload that
    checkpoint so the replicas stay in sync.
    """
    DIRECTORY = '/home/sojo/Documents/ImageCHD/ImageCHD_dataset/'
    EPOCHS = 200

    init_process_group('nccl')
    environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    local_rank = int(environ['LOCAL_RANK'])
    global_rank = int(environ['RANK'])
    # BUG FIX: this value is the process count, not a batch size — it was
    # previously bound to a misleading `batch_size` name and never used.
    world_size = int(environ['WORLD_SIZE'])

    if global_rank == 0:
        print('PyG version: ', pyg_version)
        print('Torch version: ', torch_version)
        print('GPU available: ', is_available())
        print('GPU count: ', device_count())
        print('World size: ', world_size)

    set_device(local_rank)
    empty_cache()

    model = CHD_GNN().to('cuda:' + str(local_rank))
    model = DDP(model, device_ids = [local_rank])
    adjacency = __Load_Adjacency__(DIRECTORY + 'ADJACENCY/')

    train_dataloader, train_sampler = _make_dataloader(
        DIRECTORY + 'train_dataset_info.csv', DIRECTORY, adjacency)
    eval_dataloader, eval_sampler = _make_dataloader(
        DIRECTORY + 'eval_dataset_info.csv', DIRECTORY, adjacency)

    # Inverse-frequency class weights (presumably 48538 total samples with a
    # dominant background class — TODO confirm counts against the dataset).
    loss_module = CrossEntropyLoss(weight = FloatTensor([48538./38830., 48538./1387.,
                                                        48538./1387., 48538./1387.,
                                                        48538./1387., 48538./1387.,
                                                        48538./1387., 48538./1386.]).to(local_rank))
    # ZeRO-style optimizer: Adam state is sharded across ranks.
    optimizer = ZeroRedundancyOptimizer(model.parameters(), Adam, amsgrad = True)

    writer = SummaryWriter('GNN_Experiment')

    for epoch in range(EPOCHS):
        # Slots 0-9 accumulate per-class accuracy sums; slots 10-19 count how
        # many batches contributed to the matching class (see _accumulate).
        train_metrics = array([0.] * 20)
        eval_metrics = array([0.] * 20)
        train_loss = 0.0
        eval_loss = 0.0

        if global_rank == 0:
            print('--------------------------------------------------')
            print('Epoch ', epoch + 1)

        model.train()
        # set_epoch() reseeds the samplers so shuffling differs per epoch.
        train_sampler.set_epoch(epoch = epoch + 1)
        eval_sampler.set_epoch(epoch = epoch + 1)
        barrier()

        for batch in tqdm(train_dataloader, total = len(train_dataloader)):
            optimizer.zero_grad()
            batch = batch.to(local_rank)
            preds = model(batch.x, batch.edge_index)
            _, pred_labels = max(preds, dim = 1)

            loss = loss_module(preds, batch.y)
            loss.backward()
            optimizer.step()

            # BUG FIX: the loss was computed but never accumulated, so the
            # logged training loss was always 0.
            train_loss += loss.item()
            _accumulate(train_metrics, batch.y, pred_labels)

        model.eval()

        # BUG FIX: evaluation now runs under no_grad() so no autograd graph
        # is built, and eval_loss is actually accumulated.
        with no_grad():
            for batch in tqdm(eval_dataloader, total = len(eval_dataloader)):
                batch = batch.to(local_rank)
                preds = model(batch.x, batch.edge_index)
                _, pred_labels = max(preds, dim = 1)

                eval_loss += loss_module(preds, batch.y).item()
                _accumulate(eval_metrics, batch.y, pred_labels)

        # BUG FIX: checkpoint_path must exist on *every* rank (all ranks
        # reload it below), so it is computed outside the rank-0-only block;
        # previously non-zero ranks hit a NameError.
        checkpoint_path = 'MODELS/gnn_' + str(epoch + 1) + '.checkpoint'

        if global_rank == 0:
            # NOTE(review): losses/metrics are rank-local — they are not
            # all-reduced across processes before logging.
            _report('TRAINING', train_loss, len(train_dataloader),
                    train_metrics, 'train', writer, epoch)
            _report('EVALUATION', eval_loss, len(eval_dataloader),
                    eval_metrics, 'eval', writer, epoch)
            # Save the unwrapped module so the checkpoint has no DDP prefix.
            save(model.module.state_dict(), checkpoint_path)

        barrier()
        map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank}
        # BUG FIX: the checkpoint holds the *unwrapped* module's state dict,
        # so it must be loaded into model.module — loading it into the DDP
        # wrapper would fail on the missing 'module.' key prefix.
        model.module.load_state_dict(load(checkpoint_path, map_location = map_location))


def _make_dataloader(csv_path, directory, adjacency):
    """Build a CHD_Dataset-backed DataLoader with a DistributedSampler.

    Returns (dataloader, sampler); the sampler is returned because the
    training loop must call set_epoch() on it every epoch.
    """
    metadata = read_csv(filepath_or_buffer = csv_path)
    dataset = CHD_Dataset(metadata = metadata, directory = directory,
                          adjacency = adjacency)
    sampler = DistributedSampler(dataset, shuffle = True)
    dataloader = DataLoader(dataset = dataset, batch_size = 4,
                            num_workers = 8, persistent_workers = True,
                            sampler = sampler, pin_memory = True,
                            prefetch_factor = 4)
    return dataloader, sampler


def _accumulate(acc, labels, pred_labels):
    """Fold one batch's Accuracy_Util output into the 20-slot accumulator.

    Entries equal to -1 appear to mean "class absent from this batch" and
    are skipped (TODO confirm against Metrics.Accuracy_Util). Slot i holds
    the running metric sum for class i; slot i + 10 counts the batches that
    contributed to it, for averaging at report time.
    """
    metrics = Accuracy_Util(labels, pred_labels)
    mask = argwhere(isin(metrics, -1) ^ True)
    acc[mask] += metrics[mask]
    acc[mask + 10] += 1


def _report(title, loss_sum, n_batches, metrics, tag, writer, epoch):
    """Print and TensorBoard-log the mean loss and ten per-class metrics."""
    print('-----' + title + ' METRICS-----')
    mean_loss = loss_sum / float(n_batches)
    print('Loss: ', mean_loss)
    writer.add_scalar(tag + '_loss', mean_loss, global_step = epoch + 1)
    for i in range(0, 10):
        # NOTE(review): divides by zero if class i never appeared this epoch
        # (numpy emits a warning and yields inf/nan rather than raising).
        value = metrics[i] / float(metrics[i + 10])
        print('Accuracy ', i, ': ', value)
        writer.add_scalar(tag + '_accuracy_' + str(i), value,
                          global_step = epoch + 1)
# torchrun imports and invokes this module once per process; the @record
# decorator on main() forwards worker tracebacks to the elastic agent.
if __name__ == '__main__':
    main()