Train_Script.py
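
"""Distributed (DDP) training script for the CHD_GNN graph network on the
ImageCHD dataset: class-weighted cross-entropy, ZeRO-sharded Adam,
TensorBoard logging, and per-epoch checkpointing."""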

from torch_geometric import __version__ as pyg_version
from torch_geometric.loader import DataLoader
from torch import __version__ as torch_version
from torch import FloatTensor, max, save, load, no_grad
from torch.cuda import is_available, set_device, empty_cache, device_count
from torch.distributed import init_process_group, barrier, destroy_process_group
from torch.distributed.optim.zero_redundancy_optimizer import ZeroRedundancyOptimizer
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.tensorboard.writer import SummaryWriter
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from pandas import read_csv
from os import environ
from tqdm import tqdm
from numpy import array, isin, argwhere

from Network import CHD_GNN
from Utilities import CHD_Dataset, __Load_Adjacency__
from Metrics import Accuracy_Util
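
# This script assumes a torchrun-style launch: @record, init_process_group
# and the LOCAL_RANK / RANK / WORLD_SIZE variables read in main() all rely
# on the elastic launcher's environment. Network, Utilities and Metrics are
# local modules from this repository.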

@record
def main():
    DIRECTORY = '/home/sojo/Documents/ImageCHD/ImageCHD_dataset/'
    EPOCHS = 200

    # The caching-allocator option is only read when the allocator first
    # initialises, so it is exported before any CUDA work happens.
    environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    init_process_group('nccl')
    local_rank = int(environ['LOCAL_RANK'])
    global_rank = int(environ['RANK'])
    world_size = int(environ['WORLD_SIZE'])   # process count, not a batch size

    if global_rank == 0:
        print('PyG version: ', pyg_version)
        print('Torch version: ', torch_version)
        print('GPU available: ', is_available())
        print('GPU count: ', device_count())
        print('World size: ', world_size)

    set_device(local_rank)
    empty_cache()

    # DDP broadcasts the initial weights from rank 0 and all-reduces
    # gradients during backward(); the raw network stays at model.module.
    model = CHD_GNN().to('cuda:' + str(local_rank))
    model = DDP(model, device_ids = [local_rank])
    adjacency = __Load_Adjacency__(DIRECTORY + 'ADJACENCY/')
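
    # Both splits are constructed on every rank; each DistributedSampler then
    # hands its rank a disjoint shard, so no two processes see the same graphs.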
    train_metadata = read_csv(filepath_or_buffer = DIRECTORY + 'train_dataset_info.csv')
    train_dataset = CHD_Dataset(metadata = train_metadata, directory = DIRECTORY, adjacency = adjacency)
    train_sampler = DistributedSampler(train_dataset, shuffle = True)
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = 4,
                                  num_workers = 8, persistent_workers = True,
                                  sampler = train_sampler, pin_memory = True,
                                  prefetch_factor = 4)

    eval_metadata = read_csv(filepath_or_buffer = DIRECTORY + 'eval_dataset_info.csv')
    eval_dataset = CHD_Dataset(metadata = eval_metadata, directory = DIRECTORY, adjacency = adjacency)
    eval_sampler = DistributedSampler(eval_dataset, shuffle = False)   # no need to shuffle for evaluation
    eval_dataloader = DataLoader(dataset = eval_dataset, batch_size = 4,
                                 num_workers = 8, persistent_workers = True,
                                 sampler = eval_sampler, pin_memory = True,
                                 prefetch_factor = 4)
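
    # With batch_size = 4 per rank, the effective global batch is
    # 4 * world_size graphs per optimiser step.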

    # Inverse-frequency class weights: 48538 appears to be the total sample
    # count, 38830 of it the majority class and roughly 1387 per minority class.
    loss_module = CrossEntropyLoss(weight = FloatTensor([48538./38830., 48538./1387.,
                                                         48538./1387., 48538./1387.,
                                                         48538./1387., 48538./1387.,
                                                         48538./1387., 48538./1386.]).to(local_rank))
    # ZeroRedundancyOptimizer shards Adam's state across the ranks instead of
    # replicating it on every GPU.
    optimizer = ZeroRedundancyOptimizer(model.parameters(), Adam, amsgrad = True)

    # Only rank 0 logs, so only rank 0 needs a TensorBoard writer.
    writer = SummaryWriter('GNN_Experiment') if global_rank == 0 else None
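
    # Per-epoch flow: re-seed the samplers, train over this rank's shard,
    # evaluate under no_grad, log rank-0 metrics, checkpoint, re-sync.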
    for epoch in range(EPOCHS):
        # Slots 0-9 accumulate the per-metric sums returned by Accuracy_Util;
        # slots 10-19 count the batches where that metric was defined, so
        # slot i / slot (i + 10) is the epoch average.
        train_metrics = array([0.] * 20)
        eval_metrics = array([0.] * 20)
        train_loss = 0.0
        eval_loss = 0.0

        if global_rank == 0:
            print('--------------------------------------------------')
            print('Epoch ', epoch + 1)

        model.train()
        train_sampler.set_epoch(epoch = epoch + 1)
        eval_sampler.set_epoch(epoch = epoch + 1)
        barrier()

        for batch in tqdm(train_dataloader, total = len(train_dataloader)):
            optimizer.zero_grad()
            batch = batch.to(local_rank)
            preds = model(batch.x, batch.edge_index)
            _, pred_labels = max(preds, dim = 1)

            loss = loss_module(preds, batch.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Accuracy_Util flags metrics that are undefined on this batch
            # with -1; only the defined entries are accumulated.
            metrics = Accuracy_Util(batch.y, pred_labels)
            mask = argwhere(~isin(metrics, -1))
            train_metrics[mask] += metrics[mask]
            train_metrics[mask + 10] += 1

        model.eval()

        # Gradients are not needed for evaluation; no_grad() saves memory.
        with no_grad():
            for batch in tqdm(eval_dataloader, total = len(eval_dataloader)):
                batch = batch.to(local_rank)
                preds = model(batch.x, batch.edge_index)
                _, pred_labels = max(preds, dim = 1)

                loss = loss_module(preds, batch.y)
                eval_loss += loss.item()
                metrics = Accuracy_Util(batch.y, pred_labels)
                mask = argwhere(~isin(metrics, -1))
                eval_metrics[mask] += metrics[mask]
                eval_metrics[mask + 10] += 1

        if global_rank == 0:
            # These numbers cover rank 0's shard only; they are not
            # all-reduced across processes.
            print('-----TRAINING METRICS-----')
            print('Loss: ', train_loss / float(len(train_dataloader)))
            writer.add_scalar('train_loss', train_loss / float(len(train_dataloader)),
                              global_step = epoch + 1)
            for i in range(0, 10):
                print('Accuracy ', i, ': ', train_metrics[i] / float(train_metrics[i + 10]))
                writer.add_scalar('train_accuracy_' + str(i), train_metrics[i] / float(train_metrics[i + 10]),
                                  global_step = epoch + 1)
            print('-----EVALUATION METRICS-----')
            print('Loss: ', eval_loss / float(len(eval_dataloader)))
            writer.add_scalar('eval_loss', eval_loss / float(len(eval_dataloader)),
                              global_step = epoch + 1)
            for i in range(0, 10):
                print('Accuracy ', i, ': ', eval_metrics[i] / float(eval_metrics[i + 10]))
                writer.add_scalar('eval_accuracy_' + str(i), eval_metrics[i] / float(eval_metrics[i + 10]),
                                  global_step = epoch + 1)

        # Rank 0 writes the checkpoint; everyone waits at the barrier, then
        # every rank reloads it so the replicas stay identical. The state dict
        # was taken from model.module, so it is loaded back through it too.
        checkpoint_path = 'MODELS/gnn_' + str(epoch + 1) + '.checkpoint'
        if global_rank == 0:
            save(model.module.state_dict(), checkpoint_path)

        barrier()
        map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank}
        model.module.load_state_dict(load(checkpoint_path, map_location = map_location))

    destroy_process_group()

if __name__ == '__main__':
    main()
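
# Hypothetical launch on one node with two GPUs (adjust --nproc_per_node):
#   torchrun --standalone --nproc_per_node=2 Train_Script.py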