Diff of /analysis/visualization.py [000000] .. [607087]

Switch to unified view

a b/analysis/visualization.py
1
import torch
2
import numpy as np
3
import os
4
import glob
5
import random
6
import matplotlib
7
import imageio
8
9
matplotlib.use('Agg')
10
import matplotlib.pyplot as plt
11
from analysis.molecule_builder import get_bond_order
12
13
14
##############
15
### Files ####
16
###########-->
17
18
19
def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0,
                  name='molecule', batch_mask=None):
    """Write one ``.xyz`` file per molecule contained in a flattened batch.

    Each file holds the atom count, an empty comment line and one
    ``<symbol> <x> <y> <z>`` row per atom (coordinates with 9 decimals).
    Files are named ``<path><name>_<NNN>.xyz`` where ``NNN`` is the molecule
    id found in ``batch_mask`` plus ``id_from``, zero-padded to 3 digits.

    Args:
        path: Output directory. It is concatenated directly with ``name``
            (no separator is inserted), so it should end with a path
            separator for the files to land inside the created directory.
        one_hot: 2-D tensor with one row per atom; the arg-max over dim 1
            selects the atom type index.
        positions: Tensor with one row per atom; columns 0, 1, 2 are written
            as x, y, z.
        atom_decoder: Sequence mapping an atom type index to the symbol that
            is written to the file.
        id_from: Offset added to every molecule id when building file names.
        name: File name prefix.
        batch_mask: Per-atom molecule id. When ``None`` all atoms are treated
            as belonging to a single molecule with id 0.
    """
    # An empty prefix means "current directory": nothing to create.
    if path:
        os.makedirs(path, exist_ok=True)

    if batch_mask is None:
        batch_mask = torch.zeros(len(one_hot))

    for batch_i in torch.unique(batch_mask):
        cur_batch_mask = (batch_mask == batch_i)
        # Convert once to plain Python numbers instead of formatting
        # zero-dimensional tensors element by element.
        atom_ids = torch.argmax(one_hot[cur_batch_mask], dim=1).tolist()
        coords = positions[cur_batch_mask].tolist()

        lines = ["%d\n\n" % len(atom_ids)]
        lines.extend(
            "%s %.9f %.9f %.9f\n" % (atom_decoder[atom_id],
                                      xyz[0], xyz[1], xyz[2])
            for atom_id, xyz in zip(atom_ids, coords)
        )

        file_index = int(batch_i + id_from)
        file_name = path + name + '_' + "%03d.xyz" % file_index
        # The context manager guarantees the handle is closed even if a
        # write fails part-way through.
        with open(file_name, "w", encoding='utf8') as f:
            f.writelines(lines)
41
42
43
def load_molecule_xyz(file, dataset_info):
    """Parse an ``.xyz`` file into position and one-hot atom-type tensors.

    The first line must hold the atom count, the second line is skipped, and
    each following line is read as ``<symbol> <x> <y> <z>``. Fields may be
    separated by any run of whitespace; columns after the third coordinate
    are ignored.

    Args:
        file: Path of the ``.xyz`` file (read as UTF-8).
        dataset_info: Mapping providing ``'atom_decoder'`` (its length sets
            the one-hot width) and ``'atom_encoder'`` (symbol -> column).

    Returns:
        Tuple ``(positions, one_hot)`` of tensors with shapes
        ``(n_atoms, 3)`` and ``(n_atoms, len(atom_decoder))``.

    Raises:
        ValueError: If the atom count or a coordinate is not numeric, or an
            atom line has fewer than four fields.
        IndexError: If the file lists fewer atom lines than the atom count.
        KeyError: If an atom symbol is missing from ``'atom_encoder'``.
    """
    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        f.readline()  # Comment line of the xyz format; carries no data here.
        atom_lines = f.readlines()

    atom_encoder = dataset_info['atom_encoder']
    one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder']))
    positions = torch.zeros(n_atoms, 3)

    for i in range(n_atoms):
        # split() without an argument tolerates column-aligned files
        # (repeated spaces, tabs) where split(' ') would yield empty fields.
        fields = atom_lines[i].split()
        if len(fields) < 4:
            raise ValueError(
                f'Malformed atom line {i + 1} in {file}: {atom_lines[i]!r}')
        one_hot[i, atom_encoder[fields[0]]] = 1
        positions[i, :] = torch.tensor([float(e) for e in fields[1:4]])

    return positions, one_hot
57
58
59
def load_xyz_files(path, shuffle=True):
    """Return the paths of all ``*.xyz`` files directly inside ``path``.

    Args:
        path: Directory to search; ``"/*.xyz"`` is appended to it verbatim.
        shuffle: When true, the list is shuffled in place with the global
            ``random`` generator before being returned.

    Returns:
        List of matching file paths (empty when nothing matches).
    """
    xyz_paths = glob.glob(path + "/*.xyz")
    if not shuffle:
        return xyz_paths
    random.shuffle(xyz_paths)
    return xyz_paths
64
65
66
# <----########
67
### Files ####
68
##############
69
def draw_sphere(ax, x, y, z, size, color, alpha):
    """Draw a sphere surface centred at ``(x, y, z)`` on a 3D axes.

    Args:
        ax: Object exposing ``plot_surface`` (a Matplotlib 3D axes).
        x, y, z: Coordinates of the sphere centre.
        size: Sphere radius along x and z; the y extent is scaled by 0.8.
        color: Surface colour passed through to ``plot_surface``.
        alpha: Surface opacity passed through to ``plot_surface``.
    """
    # 100 samples per angle: full turn for the azimuth, half turn for the
    # polar angle.
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sin_polar = np.sin(polar)

    surface_x = x + size * np.outer(np.cos(azimuth), sin_polar)
    # The 0.8 factor is the original author's "correct for matplotlib"
    # adjustment -- presumably compensating for the axes' aspect ratio.
    surface_y = y + size * np.outer(np.sin(azimuth), sin_polar) * 0.8
    surface_z = z + size * np.outer(np.ones(np.size(azimuth)), np.cos(polar))

    ax.plot_surface(
        surface_x, surface_y, surface_z,
        rstride=2, cstride=2,
        color=color, linewidth=0, alpha=alpha,
    )
82
    # # calculate vectors for "vertical" circle
83
    # a = np.array([-np.sin(elev / 180 * np.pi), 0, np.cos(elev / 180 * np.pi)])
84
    # b = np.array([0, 1, 0])
85
    # b = b * np.cos(rot) + np.cross(a, b) * np.sin(rot) + a * np.dot(a, b) * (
86
    #             1 - np.cos(rot))
87
    # ax.plot(np.sin(u), np.cos(u), 0, color='k', linestyle='dashed')
88
    # horiz_front = np.linspace(0, np.pi, 100)
89
    # ax.plot(np.sin(horiz_front), np.cos(horiz_front), 0, color='k')
90
    # vert_front = np.linspace(np.pi / 2, 3 * np.pi / 2, 100)
91
    # ax.plot(a[0] * np.sin(u) + b[0] * np.cos(u), b[1] * np.cos(u),
92
    #         a[2] * np.sin(u) + b[2] * np.cos(u), color='k', linestyle='dashed')
93
    # ax.plot(a[0] * np.sin(vert_front) + b[0] * np.cos(vert_front),
94
    #         b[1] * np.cos(vert_front),
95
    #         a[2] * np.sin(vert_front) + b[2] * np.cos(vert_front), color='k')
96
    #
97
    # ax.view_init(elev=elev, azim=0)
98
99
100
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color,
                  dataset_info):
    """Draw the atoms and bonds of a single molecule onto a 3D axes.

    Atoms are rendered either as shaded spheres (``spheres_3d`` true) or as
    scatter markers. A line is drawn between every pair of atoms for which
    ``get_bond_order`` returns a positive value.

    Args:
        ax: Matplotlib 3D axes to draw on.
        positions: 2-D array/tensor; columns 0, 1, 2 are x, y, z.
        atom_type: Integer atom-type indices, one per atom; used to index the
            per-type tables taken from ``dataset_info``.
        alpha: Opacity of spheres and bond lines (scatter markers use
            ``0.9 * alpha``).
        spheres_3d: Selects sphere rendering instead of scatter markers.
        hex_bg_color: Colour used for the bond lines.
        dataset_info: Mapping with ``'colors_dic'``, ``'radius_dic'`` and
            ``'atom_decoder'``, each indexable by atom type.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    # Per-type lookup tables; marker area grows with the squared radius.
    colors_dic = np.array(dataset_info['colors_dic'])
    radius_dic = np.array(dataset_info['radius_dic'])
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if not spheres_3d:
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
    else:
        for cx, cy, cz, radius, color in zip(x, y, z, radii, colors):
            draw_sphere(ax, cx.item(), cy.item(), cz.item(), 0.7 * radius,
                        color, alpha)

    base_width = 2
    n_atoms = len(x)
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))

            decoder = dataset_info['atom_decoder']
            bond_order = get_bond_order(decoder[atom_type[i]],
                                        decoder[atom_type[j]],
                                        dist)
            if not (bond_order > 0):
                # No bond between this pair.
                continue

            # Bond order 4 gets a thicker line; every other order is drawn
            # with the base width.
            width_factor = 1.5 if bond_order == 4 else 1
            ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=base_width * width_factor,
                    c=hex_bg_color, alpha=alpha)
152
153
154
def plot_data3d(positions, atom_type, dataset_info, camera_elev=0,
                camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1.):
    """Render one molecule in 3D and either save it to disk or show it.

    Args:
        positions: Atom coordinates; must support ``.abs().max().item()``
            (e.g. a torch tensor) and 2-D column indexing.
        atom_type: Per-atom integer type indices, forwarded to
            ``plot_molecule``.
        dataset_info: Dataset description forwarded to ``plot_molecule``.
        camera_elev: Camera elevation passed to ``ax.view_init``.
        camera_azim: Camera azimuth passed to ``ax.view_init``.
        save_path: Image file to write. When ``None`` the figure is shown
            with ``plt.show()`` instead.
        spheres_3d: Render atoms as spheres; also raises the output dpi from
            50 to 120 and brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background with white bonds; any other
            value gives a white background with grey bonds.
        alpha: Opacity forwarded to ``plot_molecule``.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    dark_bg = bg == 'black'
    hex_bg_color = '#FFFFFF' if dark_bg else '#666666'

    # Kept for older Matplotlib releases, where importing Axes3D is what
    # registers the '3d' projection used below.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(black if dark_bg else white)
        # Make the three background panes fully transparent.
        ax.xaxis.pane.set_alpha(0)
        ax.yaxis.pane.set_alpha(0)
        ax.zaxis.pane.set_alpha(0)
        # Private Matplotlib flag -- presumably hides the 3D axis decorations.
        ax._axis3don = False

        # `ax.w_xaxis` was removed in Matplotlib 3.8; `ax.xaxis` (already used
        # for the pane above) is the same axis object.
        ax.xaxis.line.set_color("black" if dark_bg else "white")

        plot_molecule(ax, positions, atom_type, alpha, spheres_3d,
                      hex_bg_color, dataset_info)

        # Symmetric cubic view box: grows with the molecule but is kept
        # within [3.2, 40]. The same rule is applied to every dataset.
        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0,
                        dpi=dpi)

            if spheres_3d:
                # Re-open the saved image and brighten it, clipping to the
                # valid 8-bit range.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
223
224
225
def plot_data3d_uncertainty(
        all_positions, all_atom_types, dataset_info, camera_elev=0,
        camera_azim=0,
        save_path=None, spheres_3d=False, bg='black', alpha=1.):
    """Overlay several molecules in one 3D figure and save or show it.

    Every entry of ``all_positions`` / ``all_atom_types`` is drawn with
    ``plot_molecule`` on the same axes. The view box is sized from the first
    entry only.

    Args:
        all_positions: Indexable collection of per-molecule coordinates; the
            first entry must support ``.abs().max().item()``.
        all_atom_types: Per-molecule atom-type indices, aligned with
            ``all_positions``.
        dataset_info: Dataset description; ``'name'`` selects the view-box
            scaling and the whole mapping is forwarded to ``plot_molecule``.
        camera_elev: Camera elevation passed to ``ax.view_init``.
        camera_azim: Camera azimuth passed to ``ax.view_init``.
        save_path: Image file to write. When ``None`` the figure is shown
            with ``plt.show()`` instead.
        spheres_3d: Render atoms as spheres; also raises the output dpi from
            50 to 120 and brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background with white bonds; any other
            value gives a white background with grey bonds.
        alpha: Opacity forwarded to ``plot_molecule``.

    Raises:
        ValueError: If ``dataset_info['name']`` neither contains ``'qm9'``
            nor equals ``'geom'`` or ``'pdbbind'``.
    """
    # Validate the dataset before allocating a figure so that an unsupported
    # name cannot leave an open figure behind.
    dataset_name = dataset_info['name']
    if 'qm9' in dataset_name:
        extent_divisor = 1
    elif dataset_name in ('geom', 'pdbbind'):
        extent_divisor = 2
    else:
        raise ValueError(dataset_name)

    black = (0, 0, 0)
    white = (1, 1, 1)
    dark_bg = bg == 'black'
    hex_bg_color = '#FFFFFF' if dark_bg else '#666666'

    # Kept for older Matplotlib releases, where importing Axes3D is what
    # registers the '3d' projection used below.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(black if dark_bg else white)
        # Make the three background panes fully transparent.
        ax.xaxis.pane.set_alpha(0)
        ax.yaxis.pane.set_alpha(0)
        ax.zaxis.pane.set_alpha(0)
        # Private Matplotlib flag -- presumably hides the 3D axis decorations.
        ax._axis3don = False

        # `ax.w_xaxis` was removed in Matplotlib 3.8; `ax.xaxis` (already used
        # for the pane above) is the same axis object.
        ax.xaxis.line.set_color("black" if dark_bg else "white")

        for i in range(len(all_positions)):
            plot_molecule(ax, all_positions[i], all_atom_types[i], alpha,
                          spheres_3d, hex_bg_color, dataset_info)

        # Symmetric cubic view box derived from the first molecule, kept
        # within [3.2, 40]. qm9 uses the full extent, geom/pdbbind half of it.
        max_value = all_positions[0].abs().max().item()
        axis_lim = min(40, max(max_value / extent_divisor + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0,
                        dpi=dpi)

            if spheres_3d:
                # Re-open the saved image and brighten it, clipping to the
                # valid 8-bit range.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
298
299
300
def plot_grid():
    """Show a demo ``ImageGrid`` populated with four small test images.

    The grid has 6x6 axes, but only the first four receive an image: a 10x10
    ramp, its transpose, the ramp flipped vertically and the transpose
    flipped horizontally.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    ramp = np.arange(100).reshape((10, 10))
    ramp_t = ramp.T
    images = [ramp, ramp_t, np.flipud(ramp), np.fliplr(ramp_t)]

    fig = plt.figure(figsize=(10., 10.))
    # 111 positions the grid like subplot(111); 6x6 axes, 0.1 inch padding.
    grid = ImageGrid(fig, 111, nrows_ncols=(6, 6), axes_pad=0.1)

    # Iterating over the grid yields its Axes; zip stops after the four
    # images, so the remaining axes stay empty.
    for ax, im in zip(grid, images):
        ax.imshow(im)

    plt.show()
321
322
323
def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False):
    """Render up to ``max_num`` ``.xyz`` molecules from ``path`` as PNG files.

    Files are taken from ``load_xyz_files(path)`` (shuffled by default), so
    the subset that gets rendered is random. Each image is written next to
    its source file with the last four characters of the file name (the
    ``.xyz`` suffix) replaced by ``.png``.

    Args:
        path: Directory containing the ``.xyz`` files.
        dataset_info: Dataset description forwarded to the loader and the
            plotting routine.
        max_num: Maximum number of molecules to render.
        wandb: Optional object exposing ``log`` and ``Image`` (e.g. the
            ``wandb`` module); when given, every rendered image is logged
            under the key ``'molecule'``.
        spheres_3d: Forwarded to ``plot_data3d``.
    """
    files = load_xyz_files(path)[0:max_num]
    for file in files:
        positions, one_hot = load_molecule_xyz(file, dataset_info)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # A dedicated name avoids rebinding the `path` argument mid-loop.
        image_path = file[:-4] + '.png'
        plot_data3d(positions, atom_type, dataset_info=dataset_info,
                    save_path=image_path,
                    spheres_3d=spheres_3d)

        if wandb is not None:
            # Log image(s)
            im = plt.imread(image_path)
            wandb.log({'molecule': [wandb.Image(im, caption=image_path)]})
341
342
343
def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False,
                    mode="chain"):
    """Render every ``.xyz`` file in ``path`` and combine the frames in a GIF.

    Files are processed in sorted name order. Each frame is saved as a PNG
    next to its source file, and the animation is written to ``output.gif``
    in the directory of the first frame.

    Args:
        path: Directory containing the ``.xyz`` frames.
        dataset_info: Dataset description forwarded to the loader and the
            plotting routine.
        wandb: Optional object exposing ``log`` and ``Video``; when given,
            the GIF is logged under the key ``mode``.
        spheres_3d: Forwarded to ``plot_data3d``.
        mode: Key used when logging to ``wandb``.
    """
    frame_files = sorted(load_xyz_files(path))
    save_paths = []

    for frame_file in frame_files:
        positions, one_hot = load_molecule_xyz(frame_file,
                                               dataset_info=dataset_info)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        png_path = frame_file[:-4] + '.png'
        plot_data3d(positions, atom_type, dataset_info=dataset_info,
                    save_path=png_path, spheres_3d=spheres_3d, alpha=1.0)
        save_paths.append(png_path)

    imgs = [imageio.imread(fn) for fn in save_paths]
    gif_path = os.path.dirname(save_paths[0]) + '/output.gif'
    print(f'Creating gif with {len(imgs)} images')
    # Holding the final frame by repeating it is intentionally not done here.
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is None:
        return
    wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
370
371
372
def visualize_chain_uncertainty(
        path, dataset_info, wandb=None, spheres_3d=False, mode="chain"):
    """Render a chain as overlaid triples of consecutive frames and make a GIF.

    Files are processed in sorted name order. For every window of three
    consecutive frames the three molecules are drawn semi-transparently
    (alpha 0.5) into one image, saved next to the first file of the window.
    The images are then combined into ``output.gif`` in the same directory.

    Args:
        path: Directory containing the ``.xyz`` frames.
        dataset_info: Dataset description forwarded to the loader and the
            plotting routine.
        wandb: Optional object exposing ``log`` and ``Video``; when given,
            the GIF is logged under the key ``mode``.
        spheres_3d: Forwarded to ``plot_data3d_uncertainty``.
        mode: Key used when logging to ``wandb``.

    Raises:
        ValueError: If fewer than three ``.xyz`` files are found, since no
            three-frame window can be formed.
    """
    files = sorted(load_xyz_files(path))
    if len(files) < 3:
        raise ValueError(
            f'visualize_chain_uncertainty needs at least 3 .xyz files in '
            f'{path!r}, found {len(files)}')

    save_paths = []
    for i in range(len(files) - 2):
        window = files[i:i + 3]

        # Only positions and one-hot types are needed. Slicing keeps this
        # working whether or not the loader returns additional values.
        loaded = [load_molecule_xyz(f, dataset_info=dataset_info)[:2]
                  for f in window]

        # Stacking requires the three frames to hold the same number of atoms.
        all_positions = torch.stack([pos for pos, _ in loaded], dim=0)
        one_hot = torch.stack([oh for _, oh in loaded], dim=0)

        all_atom_type = torch.argmax(one_hot, dim=2).numpy()
        fn = window[0][:-4] + '.png'
        plot_data3d_uncertainty(
            all_positions, all_atom_type, dataset_info=dataset_info,
            save_path=fn, spheres_3d=spheres_3d, alpha=0.5)
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    print(f'Creating gif with {len(imgs)} images')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
413
414
415
if __name__ == '__main__':
    # Manual smoke test: interactively display molecules from one dataset.
    # plot_grid()
    import qm9.dataset as dataset
    from configs.datasets_config import qm9_with_h, geom_with_h

    # Switch from the file-only 'Agg' backend selected at import time to an
    # interactive one so that plt.show() opens a window (macOS only).
    matplotlib.use('macosx')

    task = "visualize_molecules"
    task_dataset = 'geom'

    if task_dataset == 'qm9':
        dataset_info = qm9_with_h


        class Args:
            batch_size = 1
            num_workers = 0
            filter_n_atoms = None
            datadir = 'qm9/temp'
            dataset = 'qm9'
            remove_h = False


        cfg = Args()

        dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg)

        for i, data in enumerate(dataloaders['train']):
            positions = data['positions'].view(-1, 3)
            positions_centered = positions - positions.mean(dim=0, keepdim=True)
            one_hot = data['one_hot'].view(-1, 5).type(torch.float32)
            atom_type = torch.argmax(one_hot, dim=1).numpy()

            plot_data3d(
                positions_centered, atom_type, dataset_info=dataset_info,
                spheres_3d=True)

    elif task_dataset == 'geom':
        files = load_xyz_files('outputs/data')
        # One-hot width as produced by load_molecule_xyz for this dataset.
        n_atom_types = len(geom_with_h['atom_decoder'])
        for file in files:
            # Only positions and one-hot types are needed. Slicing keeps this
            # working whether or not the loader returns additional values.
            x, one_hot = load_molecule_xyz(file, dataset_info=geom_with_h)[:2]

            positions = x.view(-1, 3)
            positions_centered = positions - positions.mean(dim=0, keepdim=True)
            one_hot = one_hot.view(-1, n_atom_types).type(torch.float32)
            atom_type = torch.argmax(one_hot, dim=1).numpy()

            # Drop rows whose three coordinates are all exactly zero.
            mask = (x == 0).sum(1) != 3
            positions_centered = positions_centered[mask]
            atom_type = atom_type[mask]

            plot_data3d(
                positions_centered, atom_type, dataset_info=geom_with_h,
                spheres_3d=False)

    else:
        # Report the offending selector, not the imported `dataset` module.
        raise ValueError(task_dataset)