Diff of /Train_Script.py [000000] .. [b52eda]

--- a
+++ b/Train_Script.py
@@ -0,0 +1,174 @@
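+# Multi-GPU training script for the ImageCHD GNN; intended to be launched with
+# torchrun (example invocation: torchrun --nproc_per_node=<gpus> Train_Script.py).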
+from torch_geometric import __version__ as pyg_version
+from torch_geometric.loader import DataLoader
+from torch import __version__ as torch_version
+from torch import FloatTensor, max as torch_max, save, load, no_grad
+from torch.cuda import is_available, set_device, empty_cache, device_count
+from torch.distributed import init_process_group, destroy_process_group, barrier
+from torch.distributed.optim.zero_redundancy_optimizer import ZeroRedundancyOptimizer
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.utils.tensorboard.writer import SummaryWriter
+from torch.optim import Adam
+from torch.nn import CrossEntropyLoss
+from pandas import read_csv
+from os import environ
+from tqdm import tqdm
+from numpy import zeros, isin, argwhere
+
+from Network import CHD_GNN
+from Utilities import CHD_Dataset, __Load_Adjacency__
+from Metrics import Accuracy_Util
+
+@record
+def main():
+    DIRECTORY = '/home/sojo/Documents/ImageCHD/ImageCHD_dataset/'
+    EPOCHS = 200
+
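+    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for each spawned process.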
+    # Set the allocator config before any CUDA context is created.
+    environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+    init_process_group('nccl')
+    local_rank = int(environ['LOCAL_RANK'])
+    global_rank = int(environ['RANK'])
+    world_size = int(environ['WORLD_SIZE'])
+
+    if global_rank == 0:
+        print('PyG version: ', pyg_version)
+        print('Torch version: ', torch_version)
+        print('GPU available: ', is_available())
+        print('GPU count: ', device_count())
+        print('World size: ', world_size)
+
+    set_device(local_rank)
+    empty_cache()
+
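+    # One process drives one GPU: move this rank's replica to its device and
+    # let DDP average gradients across ranks during backward.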
+    model = CHD_GNN().to('cuda:' + str(local_rank))
+    model = DDP(model, device_ids = [local_rank])
+    adjacency = __Load_Adjacency__(DIRECTORY + 'ADJACENCY/')
+
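+    # DistributedSampler shards the data so each rank sees a disjoint subset;
+    # with a per-rank batch size of 4, the global batch is 4 * world_size.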
+    train_metadata = read_csv(filepath_or_buffer = DIRECTORY + 'train_dataset_info.csv')
+    train_dataset = CHD_Dataset(metadata = train_metadata, directory = DIRECTORY, adjacency = adjacency)
+    train_sampler = DistributedSampler(train_dataset, shuffle = True)
+    train_dataloader = DataLoader(dataset = train_dataset, batch_size = 4,
+                                num_workers = 8, persistent_workers = True,
+                                sampler = train_sampler, pin_memory = True,
+                                prefetch_factor = 4)
+
+    eval_metadata = read_csv(filepath_or_buffer = DIRECTORY + 'eval_dataset_info.csv')
+    eval_dataset = CHD_Dataset(metadata = eval_metadata, directory = DIRECTORY, adjacency = adjacency)
+    eval_sampler = DistributedSampler(eval_dataset, shuffle = True)
+    eval_dataloader = DataLoader(dataset = eval_dataset, batch_size = 4,
+                                num_workers = 8, persistent_workers = True,
+                                sampler = eval_sampler, pin_memory = True,
+                                prefetch_factor = 4)
+
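+    # The weights appear to be inverse class frequencies: total sample count
+    # divided by per-class counts (38830 + 6 * 1387 + 1386 = 48538).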
+    loss_module = CrossEntropyLoss(weight = FloatTensor([48538./38830., 48538./1387.,
+                                                        48538./1387., 48538./1387.,
+                                                        48538./1387., 48538./1387.,
+                                                        48538./1387., 48538./1386.]).to(local_rank))
+    # ZeroRedundancyOptimizer shards optimizer state across ranks (ZeRO-1 style).
+    optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class = Adam, amsgrad = True)
+
+    # Only rank 0 logs to TensorBoard; other ranks skip logging entirely.
+    writer = SummaryWriter('GNN_Experiment') if global_rank == 0 else None
+
+    for epoch in range(EPOCHS):
+        # Slots 0-9 accumulate the metrics returned by Accuracy_Util; slots
+        # 10-19 count the batches in which each metric was defined (not -1).
+        train_metrics = zeros(20)
+        eval_metrics = zeros(20)
+        train_loss = 0.0
+        eval_loss = 0.0
+
+        if global_rank == 0:
+            print('--------------------------------------------------')
+            print('Epoch ', epoch + 1)
+
+        model.train()
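+        # Re-seed the samplers so shuffling differs every epoch.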
+        train_sampler.set_epoch(epoch = epoch + 1)
+        eval_sampler.set_epoch(epoch = epoch + 1)
+        barrier()
+
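+        # Forward/backward over this rank's shard; DDP averages gradients
+        # across ranks during backward, before optimizer.step().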
+        for batch in tqdm(train_dataloader, total = len(train_dataloader)):
+            optimizer.zero_grad()
+            batch = batch.to(local_rank)
+            preds = model(batch.x, batch.edge_index)
+            _, pred_labels = torch_max(preds, dim = 1)
+
+            loss = loss_module(preds, batch.y)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
+
+            # Accumulate only the metrics defined for this batch; Accuracy_Util
+            # returns -1 where a metric could not be computed.
+            metrics = Accuracy_Util(batch.y, pred_labels)
+            mask = argwhere(~isin(metrics, -1))
+            train_metrics[mask] += metrics[mask]
+            train_metrics[mask + 10] += 1
+
+        model.eval()
+
+        # No gradients are needed for evaluation, so disable autograd.
+        with no_grad():
+            for batch in tqdm(eval_dataloader, total = len(eval_dataloader)):
+                batch = batch.to(local_rank)
+                preds = model(batch.x, batch.edge_index)
+                _, pred_labels = torch_max(preds, dim = 1)
+
+                loss = loss_module(preds, batch.y)
+                eval_loss += loss.item()
+                metrics = Accuracy_Util(batch.y, pred_labels)
+                mask = argwhere(~isin(metrics, -1))
+                eval_metrics[mask] += metrics[mask]
+                eval_metrics[mask + 10] += 1
+
+        # Each rank accumulates statistics over its own shard; rank 0 logs its
+        # local values. The checkpoint path must exist on every rank because
+        # all ranks reload it below.
+        checkpoint_path = 'MODELS/gnn_' + str(epoch + 1) + '.checkpoint'
+
+        if global_rank == 0:
+            print('-----TRAINING METRICS-----')
+            print('Loss: ', train_loss / float(len(train_dataloader)))
+            writer.add_scalar('train_loss', train_loss / float(len(train_dataloader)),
+                              global_step = epoch + 1)
+            for i in range(0, 10):
+                print('Accuracy ', i, ': ', train_metrics[i] / float(train_metrics[i + 10]))
+                writer.add_scalar('train_accuracy_' + str(i), train_metrics[i] / float(train_metrics[i + 10]),
+                                  global_step = epoch + 1)
+            print('-----EVALUATION METRICS-----')
+            print('Loss: ', eval_loss / float(len(eval_dataloader)))
+            writer.add_scalar('eval_loss', eval_loss / float(len(eval_dataloader)),
+                              global_step = epoch + 1)
+            for i in range(0, 10):
+                print('Accuracy ', i, ': ', eval_metrics[i] / float(eval_metrics[i + 10]))
+                writer.add_scalar('eval_accuracy_' + str(i), eval_metrics[i] / float(eval_metrics[i + 10]),
+                                  global_step = epoch + 1)
+
+            # Save the unwrapped module so the checkpoint is DDP-agnostic.
+            save(model.module.state_dict(), checkpoint_path)
+
+        # Wait for rank 0 to finish writing, then reload the checkpoint on all
+        # ranks so every replica starts the next epoch with identical weights.
+        barrier()
+        map_location = {'cuda:0': 'cuda:%d' % local_rank}
+        model.module.load_state_dict(load(checkpoint_path, map_location = map_location))
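+
+    # Flush TensorBoard and tear down the process group before exiting.
+    if global_rank == 0:
+        writer.close()
+    destroy_process_group()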
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file