|
a |
|
b/Test.py |
|
|
1 |
from torch_geometric import __version__ as pyg_version |
|
|
2 |
from torch import __version__ as torch_version |
|
|
3 |
from torch import device, as_tensor, max |
|
|
4 |
from torch.cuda import is_available, can_device_access_peer |
|
|
5 |
from Network import CHD_GNN |
|
|
6 |
from Utilities import CHD_Dataset |
|
|
7 |
from pandas import read_csv |
|
|
8 |
from Graph_Conversion import Convert_To_Image |
|
|
9 |
from matplotlib import pyplot as plt |
|
|
10 |
from torch.distributed import init_process_group |
|
|
11 |
from os import environ |
|
|
12 |
|
|
|
13 |
DIRECTORY = '/home/sojo/Documents/ImageCHD/ImageCHD_dataset/' |
|
|
14 |
|
|
|
15 |
print(pyg_version) |
|
|
16 |
print(torch_version) |
|
|
17 |
print(is_available()) |
|
|
18 |
print(can_device_access_peer(device('cuda:1'), device('cuda:0'))) |
|
|
19 |
|
|
|
20 |
init_process_group('nccl') |
|
|
21 |
local_rank = int(environ['LOCAL_RANK']) |
|
|
22 |
global_rank = int(environ['RANK']) |
|
|
23 |
batch_size = int(environ['WORLD_SIZE']) |
|
|
24 |
|
|
|
25 |
if global_rank == 0: |
|
|
26 |
print('PyG version: ', pyg_version) |
|
|
27 |
print('Torch version: ', torch_version) |
|
|
28 |
print('GPU available: ', is_available()) |
|
|
29 |
print(batch_size) |
|
|
30 |
|
|
|
31 |
# gpu = device('cuda:0') |
|
|
32 |
# print(gpu) |
|
|
33 |
# gpu = device('cuda:1') |
|
|
34 |
# print(gpu) |
|
|
35 |
# testing = CHD_GNN().to(gpu) |
|
|
36 |
# metadata = read_csv(filepath_or_buffer = DIRECTORY + 'train_dataset_info.csv') |
|
|
37 |
# dataset = CHD_Dataset(metadata = metadata, directory = DIRECTORY) |
|
|
38 |
# sample = dataset.get(76) |
|
|
39 |
|
|
|
40 |
# print(sample.x.type()) |
|
|
41 |
# print(sample.edge_index.type()) |
|
|
42 |
# print(sample.y.type()) |
|
|
43 |
|
|
|
44 |
# print(sample.x[0][0].type()) |
|
|
45 |
# print(sample.edge_index[0][0].type()) |
|
|
46 |
# print(sample.y[0][0].type()) |
|
|
47 |
|
|
|
48 |
# out = testing(sample.x, sample.edge_index) |
|
|
49 |
# print(out.shape) |
|
|
50 |
# print(out.type()) |
|
|
51 |
# _, label = max(out, dim = 1) |
|
|
52 |
# print(label) |
|
|
53 |
# print(label.shape) |
|
|
54 |
# print(label.type()) |
|
|
55 |
|
|
|
56 |
# result = Convert_To_Image(label, sample.adj_count) |
|
|
57 |
# plt.imshow(result, cmap = 'gray') |
|
|
58 |
# plt.show() |