Diff of /analysis/visualization.py [000000] .. [607087]

Switch to side-by-side view

--- a
+++ b/analysis/visualization.py
@@ -0,0 +1,472 @@
+import torch
+import numpy as np
+import os
+import glob
+import random
+import matplotlib
+import imageio
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from analysis.molecule_builder import get_bond_order
+
+
+##############
+### Files ####
+###########-->
+
+
+def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0,
+                  name='molecule', batch_mask=None):
+    """Write every molecule of a batch to its own ``.xyz`` file.
+
+    Atoms are grouped by the value they carry in ``batch_mask``. Each group
+    is written to ``<path>/<name>_<NNN>.xyz`` where ``NNN`` is the group id
+    plus ``id_from``, zero-padded to three digits. A file holds the atom
+    count, an empty comment line, and one ``<symbol> <x> <y> <z>`` line per
+    atom with nine decimals.
+
+    Args:
+        path: Output directory. It is created when missing; a trailing
+            separator is optional.
+        one_hot: Per-atom type scores; the argmax along dim 1 selects the
+            atom type.
+        positions: Per-atom coordinates; columns 0, 1 and 2 are written.
+        atom_decoder: Sequence mapping an atom-type index to its symbol.
+        id_from: Offset added to the group id in the file name.
+        name: File-name prefix.
+        batch_mask: Per-atom group id. ``None`` puts all atoms into a single
+            molecule with id 0.
+    """
+    # exist_ok tolerates an already existing directory only; any other OS
+    # error (e.g. missing permissions) now surfaces here instead of being
+    # swallowed and failing later at open().
+    if path:
+        os.makedirs(path, exist_ok=True)
+
+    if batch_mask is None:
+        batch_mask = torch.zeros(len(one_hot))
+
+    for batch_i in torch.unique(batch_mask):
+        cur_batch_mask = (batch_mask == batch_i)
+        # Plain Python lists keep the string formatting below free of
+        # implicit tensor conversions.
+        atoms = torch.argmax(one_hot[cur_batch_mask], dim=1).tolist()
+        batch_pos = positions[cur_batch_mask].tolist()
+        mol_id = int(batch_i.item()) + id_from
+        # os.path.join works with and without a trailing separator on
+        # `path`; plain concatenation wrote next to the directory instead.
+        file_name = os.path.join(path, "%s_%03d.xyz" % (name, mol_id))
+        with open(file_name, "w") as f:
+            f.write("%d\n\n" % len(atoms))
+            for atom, pos in zip(atoms, batch_pos):
+                f.write("%s %.9f %.9f %.9f\n"
+                        % (atom_decoder[atom], pos[0], pos[1], pos[2]))
+
+
+def load_molecule_xyz(file, dataset_info):
+    """Read one molecule from an ``.xyz`` file.
+
+    The file must hold the atom count on its first line, one comment line,
+    and then one ``<symbol> <x> <y> <z>`` record per atom. Fields may be
+    separated by any run of whitespace; columns after ``z`` are ignored.
+
+    Args:
+        file: Path of the ``.xyz`` file.
+        dataset_info: Mapping with ``'atom_decoder'`` (its length sets the
+            one-hot width) and ``'atom_encoder'`` (symbol -> column index).
+
+    Returns:
+        Tuple ``(positions, one_hot)`` of tensors shaped ``(n_atoms, 3)``
+        and ``(n_atoms, len(atom_decoder))``.
+
+    Raises:
+        ValueError: If the file holds fewer atom records than announced or
+            a record has fewer than three coordinates.
+        KeyError: If an atom symbol is not in ``atom_encoder``.
+    """
+    atom_encoder = dataset_info['atom_encoder']
+    with open(file, encoding='utf8') as f:
+        n_atoms = int(f.readline())
+        f.readline()  # Skip the comment line.
+        # split() without argument handles repeated spaces, tabs and the
+        # trailing newline alike.
+        records = [line.split() for line in f]
+
+    if len(records) < n_atoms:
+        raise ValueError('%s announces %d atoms but holds only %d records'
+                         % (file, n_atoms, len(records)))
+
+    one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder']))
+    positions = torch.zeros(n_atoms, 3)
+    for i, fields in enumerate(records[:n_atoms]):
+        if len(fields) < 4:
+            raise ValueError('%s: malformed atom record %d: %r'
+                             % (file, i, fields))
+        one_hot[i, atom_encoder[fields[0]]] = 1
+        positions[i, :] = torch.tensor([float(v) for v in fields[1:4]])
+    return positions, one_hot
+
+
+def load_xyz_files(path, shuffle=True):
+    """Return the paths matching ``<path>/*.xyz``.
+
+    With ``shuffle`` set (the default) the list is shuffled in place with
+    the module-level ``random`` generator before it is returned; otherwise
+    it comes back in the order ``glob`` produced it.
+    """
+    xyz_paths = glob.glob(path + "/*.xyz")
+    if not shuffle:
+        return xyz_paths
+    random.shuffle(xyz_paths)
+    return xyz_paths
+
+
+# <----########
+### Files ####
+##############
+def draw_sphere(ax, x, y, z, size, color, alpha):
+    """Draw a sphere surface centred at ``(x, y, z)`` on ``ax``.
+
+    The surface is sampled on a 100 x 100 grid of angles, scaled by
+    ``size`` and handed to ``ax.plot_surface`` with the given ``color`` and
+    ``alpha``.
+    """
+    azimuth = np.linspace(0, 2 * np.pi, 100)
+    polar = np.linspace(0, np.pi, 100)
+    sin_polar = np.sin(polar)
+
+    offset_x = size * np.outer(np.cos(azimuth), sin_polar)
+    # The 0.8 factor on y is a correction for matplotlib.
+    offset_y = size * np.outer(np.sin(azimuth), sin_polar) * 0.8
+    offset_z = size * np.outer(np.ones(np.size(azimuth)), np.cos(polar))
+
+    surface_style = dict(rstride=2, cstride=2, color=color, linewidth=0,
+                         alpha=alpha)
+    ax.plot_surface(x + offset_x, y + offset_y, z + offset_z,
+                    **surface_style)
+
+
+def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color,
+                  dataset_info):
+    """Draw the atoms and bonds of one molecule on a 3D axes.
+
+    Args:
+        ax: 3D axes providing ``scatter``, ``plot`` and ``plot_surface``.
+        positions: ``(n_atoms, >=3)`` coordinates as a torch tensor or
+            anything ``np.asarray`` accepts; columns 0..2 are x, y, z.
+        atom_type: Per-atom integer indices into the ``dataset_info``
+            tables.
+        alpha: Opacity of spheres and bonds; scatter markers use
+            ``0.9 * alpha``.
+        spheres_3d: Draw every atom with ``draw_sphere`` instead of one
+            ``scatter`` call.
+        hex_bg_color: Colour used for the bond lines.
+        dataset_info: Mapping with ``'colors_dic'``, ``'radius_dic'`` and
+            ``'atom_decoder'``, each indexed by atom type.
+    """
+    # Convert once up front: the pair loop below then works on NumPy
+    # scalars, and tensors on another device or with grad are accepted too.
+    if isinstance(positions, torch.Tensor):
+        coords = positions.detach().cpu().numpy()
+    else:
+        coords = np.asarray(positions)
+    x = coords[:, 0]
+    y = coords[:, 1]
+    z = coords[:, 2]
+
+    atom_decoder = dataset_info['atom_decoder']
+    colors_dic = np.array(dataset_info['colors_dic'])
+    radius_dic = np.array(dataset_info['radius_dic'])
+    area_dic = 1500 * radius_dic ** 2
+
+    areas = area_dic[atom_type]
+    radii = radius_dic[atom_type]
+    colors = colors_dic[atom_type]
+
+    if spheres_3d:
+        for cx, cy, cz, radius, color in zip(x, y, z, radii, colors):
+            draw_sphere(ax, cx.item(), cy.item(), cz.item(), 0.7 * radius,
+                        color, alpha)
+    else:
+        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
+
+    line_width = 2
+    n_atoms = len(x)
+    for i in range(n_atoms):
+        for j in range(i + 1, n_atoms):
+            dist = np.sqrt(np.sum((coords[i, :3] - coords[j, :3]) ** 2))
+            bond_order = get_bond_order(atom_decoder[atom_type[i]],
+                                        atom_decoder[atom_type[j]],
+                                        dist)
+            if bond_order > 0:
+                # Bond order 4 is drawn thicker; all other bonds share one
+                # width regardless of their order.
+                linewidth_factor = 1.5 if bond_order == 4 else 1
+                ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
+                        linewidth=line_width * linewidth_factor,
+                        c=hex_bg_color, alpha=alpha)
+
+
+def plot_data3d(positions, atom_type, dataset_info, camera_elev=0,
+                camera_azim=0, save_path=None, spheres_3d=False,
+                bg='black', alpha=1.):
+    """Render one molecule in a 3D figure and save or show it.
+
+    Args:
+        positions: Torch tensor of atom coordinates; forwarded to
+            ``plot_molecule`` and used to size the axes.
+        atom_type: Per-atom type indices, forwarded to ``plot_molecule``.
+        dataset_info: Dataset description, forwarded to ``plot_molecule``.
+        camera_elev: Elevation passed to ``ax.view_init``.
+        camera_azim: Azimuth passed to ``ax.view_init``.
+        save_path: Target image path. ``None`` shows the figure instead.
+        spheres_3d: Draw atoms as spheres. Saved images then use dpi 120
+            instead of 50 and are brightened by a factor of 1.4.
+        bg: ``'black'`` gives a black background with white bonds; any
+            other value a white background with grey (``#666666``) bonds.
+        alpha: Opacity forwarded to ``plot_molecule``.
+    """
+    black = (0, 0, 0)
+    white = (1, 1, 1)
+    hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666'
+
+    # Kept from the original: older Matplotlib releases need this import to
+    # register the '3d' projection.
+    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
+    fig = plt.figure()
+    try:
+        ax = fig.add_subplot(projection='3d')
+        ax.set_aspect('auto')
+        ax.view_init(elev=camera_elev, azim=camera_azim)
+        ax.set_facecolor(black if bg == 'black' else white)
+        ax.xaxis.pane.set_alpha(0)
+        ax.yaxis.pane.set_alpha(0)
+        ax.zaxis.pane.set_alpha(0)
+        ax._axis3don = False
+
+        # `ax.xaxis` replaces the `ax.w_xaxis` alias, which current
+        # Matplotlib no longer provides; it is the same axis object the
+        # pane calls above already use.
+        ax.xaxis.line.set_color("black" if bg == 'black' else "white")
+
+        plot_molecule(ax, positions, atom_type, alpha, spheres_3d,
+                      hex_bg_color, dataset_info)
+
+        # Symmetric limits derived from the largest absolute coordinate,
+        # kept within [3.2, 40].
+        max_value = positions.abs().max().item()
+        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
+        ax.set_xlim(-axis_lim, axis_lim)
+        ax.set_ylim(-axis_lim, axis_lim)
+        ax.set_zlim(-axis_lim, axis_lim)
+
+        dpi = 120 if spheres_3d else 50
+
+        if save_path is not None:
+            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0,
+                        dpi=dpi)
+
+            if spheres_3d:
+                # Re-open the saved image and brighten it in place.
+                img = imageio.imread(save_path)
+                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
+                imageio.imsave(save_path, img_brighter)
+        else:
+            plt.show()
+    finally:
+        # Release the figure even when drawing or saving fails.
+        plt.close(fig)
+
+
+def plot_data3d_uncertainty(
+        all_positions, all_atom_types, dataset_info, camera_elev=0,
+        camera_azim=0,
+        save_path=None, spheres_3d=False, bg='black', alpha=1.):
+    """Overlay several molecules in one 3D figure and save or show it.
+
+    Args:
+        all_positions: Sequence of torch position tensors, one per
+            molecule. Axis limits are derived from the first entry only.
+        all_atom_types: Per-molecule atom-type indices, aligned with
+            ``all_positions``.
+        dataset_info: Dataset description; ``dataset_info['name']`` selects
+            the axis scaling, the rest is forwarded to ``plot_molecule``.
+        camera_elev: Elevation passed to ``ax.view_init``.
+        camera_azim: Azimuth passed to ``ax.view_init``.
+        save_path: Target image path. ``None`` shows the figure instead.
+        spheres_3d: Draw atoms as spheres. Saved images then use dpi 120
+            instead of 50 and are brightened by a factor of 1.4.
+        bg: ``'black'`` gives a black background with white bonds; any
+            other value a white background with grey (``#666666``) bonds.
+        alpha: Opacity forwarded to ``plot_molecule``.
+
+    Raises:
+        ValueError: If the dataset name neither contains ``'qm9'`` nor
+            equals ``'geom'`` or ``'pdbbind'``.
+    """
+    # Validate the dataset name before any figure is created.
+    dataset_name = dataset_info['name']
+    if 'qm9' in dataset_name:
+        limit_divisor = 1
+    elif dataset_name in ('geom', 'pdbbind'):
+        limit_divisor = 2
+    else:
+        raise ValueError(dataset_name)
+
+    black = (0, 0, 0)
+    white = (1, 1, 1)
+    hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666'
+
+    # Kept from the original: older Matplotlib releases need this import to
+    # register the '3d' projection.
+    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
+    fig = plt.figure()
+    try:
+        ax = fig.add_subplot(projection='3d')
+        ax.set_aspect('auto')
+        ax.view_init(elev=camera_elev, azim=camera_azim)
+        ax.set_facecolor(black if bg == 'black' else white)
+        ax.xaxis.pane.set_alpha(0)
+        ax.yaxis.pane.set_alpha(0)
+        ax.zaxis.pane.set_alpha(0)
+        ax._axis3don = False
+
+        # `ax.xaxis` replaces the `ax.w_xaxis` alias, which current
+        # Matplotlib no longer provides; it is the same axis object the
+        # pane calls above already use.
+        ax.xaxis.line.set_color("black" if bg == 'black' else "white")
+
+        for positions, atom_type in zip(all_positions, all_atom_types):
+            plot_molecule(ax, positions, atom_type, alpha, spheres_3d,
+                          hex_bg_color, dataset_info)
+
+        # Symmetric limits from the first molecule's largest absolute
+        # coordinate, kept within [3.2, 40].
+        max_value = all_positions[0].abs().max().item()
+        axis_lim = min(40, max(max_value / limit_divisor + 0.3, 3.2))
+        ax.set_xlim(-axis_lim, axis_lim)
+        ax.set_ylim(-axis_lim, axis_lim)
+        ax.set_zlim(-axis_lim, axis_lim)
+
+        dpi = 120 if spheres_3d else 50
+
+        if save_path is not None:
+            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0,
+                        dpi=dpi)
+
+            if spheres_3d:
+                # Re-open the saved image and brighten it in place.
+                img = imageio.imread(save_path)
+                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
+                imageio.imsave(save_path, img_brighter)
+        else:
+            plt.show()
+    finally:
+        # Release the figure even when drawing or saving fails.
+        plt.close(fig)
+
+
+def plot_grid():
+    """Show a demo figure: four small test images on a 6 x 6 ImageGrid."""
+    import matplotlib.pyplot as plt
+    from mpl_toolkits.axes_grid1 import ImageGrid
+
+    base = np.arange(100).reshape((10, 10))
+    transposed = base.T
+    images = [base, transposed, np.flipud(base), np.fliplr(transposed)]
+
+    fig = plt.figure(figsize=(10., 10.))
+    # 111 places the grid like subplot(111); it holds 6 x 6 axes with a
+    # padding of 0.1 inch between them.
+    grid = ImageGrid(fig, 111, nrows_ncols=(6, 6), axes_pad=0.1)
+
+    # Iterating over the grid yields its axes; zip stops after the four
+    # images, so only the first four axes receive one.
+    for ax, im in zip(grid, images):
+        ax.imshow(im)
+
+    plt.show()
+
+
+def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False):
+    """Render up to ``max_num`` of the ``.xyz`` files in ``path`` as PNGs.
+
+    The files are taken from ``load_xyz_files`` (shuffled by default), so
+    the selection is random when the directory holds more than ``max_num``
+    molecules. Each image is written next to its source file with the
+    extension replaced by ``.png``.
+
+    Args:
+        path: Directory containing the ``.xyz`` files.
+        dataset_info: Dataset description used for loading and plotting.
+        max_num: Maximum number of molecules to render.
+        wandb: Optional wandb-like module; when given, every image is
+            logged under the key ``'molecule'`` with its path as caption.
+        spheres_3d: Forwarded to ``plot_data3d``.
+    """
+    files = load_xyz_files(path)[0:max_num]
+    for file in files:
+        positions, one_hot = load_molecule_xyz(file, dataset_info)
+        atom_type = torch.argmax(one_hot, dim=1).numpy()
+        # Built once and kept in its own name so the `path` argument is not
+        # overwritten inside the loop.
+        png_path = file[:-4] + '.png'
+        plot_data3d(positions, atom_type, dataset_info=dataset_info,
+                    save_path=png_path,
+                    spheres_3d=spheres_3d)
+
+        if wandb is not None:
+            # Log image(s)
+            im = plt.imread(png_path)
+            wandb.log({'molecule': [wandb.Image(im, caption=png_path)]})
+
+
+def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False,
+                    mode="chain"):
+    """Render every ``.xyz`` file in ``path`` and combine them into a gif.
+
+    Files are processed in sorted path order. Each frame is saved next to
+    its source file as ``.png``; the animation is written to ``output.gif``
+    in the same directory. Nothing is written when no ``.xyz`` file exists.
+
+    Args:
+        path: Directory containing the chain's ``.xyz`` files.
+        dataset_info: Dataset description used for loading and plotting.
+        wandb: Optional wandb-like module; when given, the gif is logged as
+            a video under the key ``mode``.
+        spheres_3d: Forwarded to ``plot_data3d``.
+        mode: Key under which the gif is logged.
+    """
+    # Sorting fixes the frame order regardless of the shuffle done inside
+    # load_xyz_files.
+    files = sorted(load_xyz_files(path))
+    if not files:
+        # Without frames there is no directory to write to and no gif to
+        # build; previously this crashed with an IndexError.
+        print(f'No .xyz files found in {path}; skipping gif creation')
+        return
+
+    save_paths = []
+    for file in files:
+        positions, one_hot = load_molecule_xyz(file, dataset_info=dataset_info)
+
+        atom_type = torch.argmax(one_hot, dim=1).numpy()
+        fn = file[:-4] + '.png'
+        plot_data3d(positions, atom_type, dataset_info=dataset_info,
+                    save_path=fn, spheres_3d=spheres_3d, alpha=1.0)
+        save_paths.append(fn)
+
+    imgs = [imageio.imread(fn) for fn in save_paths]
+    gif_path = os.path.join(os.path.dirname(save_paths[0]), 'output.gif')
+    print(f'Creating gif with {len(imgs)} images')
+    imageio.mimsave(gif_path, imgs, subrectangles=True)
+
+    if wandb is not None:
+        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
+
+
+def visualize_chain_uncertainty(
+        path, dataset_info, wandb=None, spheres_3d=False, mode="chain"):
+    """Render a chain as a gif whose frames overlay three adjacent states.
+
+    Files are processed in sorted path order. Frame ``i`` draws files
+    ``i``, ``i + 1`` and ``i + 2`` together at ``alpha=0.5`` and is saved
+    next to file ``i`` as ``.png``; the animation is written to
+    ``output.gif`` in the same directory. Nothing is written when fewer
+    than three files exist. The three files of a frame must hold the same
+    number of atoms because their tensors are stacked.
+
+    Args:
+        path: Directory containing the chain's ``.xyz`` files.
+        dataset_info: Dataset description used for loading and plotting.
+        wandb: Optional wandb-like module; when given, the gif is logged as
+            a video under the key ``mode``.
+        spheres_3d: Forwarded to ``plot_data3d_uncertainty``.
+        mode: Key under which the gif is logged.
+    """
+    # Sorting fixes the frame order regardless of the shuffle done inside
+    # load_xyz_files.
+    files = sorted(load_xyz_files(path))
+    save_paths = []
+
+    # One frame per window of three consecutive files; the range is empty
+    # for fewer than three files.
+    for i in range(len(files) - 2):
+        window = files[i:i + 3]
+        # load_molecule_xyz returns (positions, one_hot); unpacking three
+        # values here used to raise a ValueError on every call.
+        loaded = [load_molecule_xyz(f, dataset_info=dataset_info)
+                  for f in window]
+
+        all_positions = torch.stack([pos for pos, _ in loaded], dim=0)
+        one_hot = torch.stack([oh for _, oh in loaded], dim=0)
+
+        all_atom_type = torch.argmax(one_hot, dim=2).numpy()
+        fn = window[0][:-4] + '.png'
+        plot_data3d_uncertainty(
+            all_positions, all_atom_type, dataset_info=dataset_info,
+            save_path=fn, spheres_3d=spheres_3d, alpha=0.5)
+        save_paths.append(fn)
+
+    if not save_paths:
+        # Without frames there is no directory to write to and no gif to
+        # build.
+        print(f'Fewer than three .xyz files in {path}; '
+              'skipping gif creation')
+        return
+
+    imgs = [imageio.imread(fn) for fn in save_paths]
+    gif_path = os.path.join(os.path.dirname(save_paths[0]), 'output.gif')
+    print(f'Creating gif with {len(imgs)} images')
+    imageio.mimsave(gif_path, imgs, subrectangles=True)
+
+    if wandb is not None:
+        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})
+
+
+if __name__ == '__main__':
+    # Manual smoke test: shows molecules of the selected dataset one by one.
+    # plot_grid()
+    import qm9.dataset as dataset
+    from configs.datasets_config import qm9_with_h, geom_with_h
+
+    # Switch from the module-level 'Agg' backend to an interactive one so
+    # plt.show() opens a window. NOTE: 'macosx' only works on macOS.
+    matplotlib.use('macosx')
+
+    task_dataset = 'geom'
+
+    if task_dataset == 'qm9':
+        dataset_info = qm9_with_h
+
+
+        class Args:
+            # Minimal stand-in for the configuration object handed to
+            # retrieve_dataloaders.
+            batch_size = 1
+            num_workers = 0
+            filter_n_atoms = None
+            datadir = 'qm9/temp'
+            dataset = 'qm9'
+            remove_h = False
+
+
+        cfg = Args()
+
+        dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg)
+
+        for i, data in enumerate(dataloaders['train']):
+            positions = data['positions'].view(-1, 3)
+            positions_centered = positions - positions.mean(dim=0, keepdim=True)
+            one_hot = data['one_hot'].view(-1, 5).type(torch.float32)
+            atom_type = torch.argmax(one_hot, dim=1).numpy()
+
+            plot_data3d(
+                positions_centered, atom_type, dataset_info=dataset_info,
+                spheres_3d=True)
+
+    elif task_dataset == 'geom':
+        files = load_xyz_files('outputs/data')
+        for file in files:
+            # load_molecule_xyz returns (positions, one_hot); unpacking
+            # three values here used to raise a ValueError.
+            x, one_hot = load_molecule_xyz(file, dataset_info=geom_with_h)
+
+            positions = x.view(-1, 3)
+            positions_centered = positions - positions.mean(dim=0, keepdim=True)
+            one_hot = one_hot.view(-1, 16).type(torch.float32)
+            atom_type = torch.argmax(one_hot, dim=1).numpy()
+
+            # Drop atoms whose three coordinates are all exactly zero.
+            mask = (x == 0).sum(1) != 3
+            positions_centered = positions_centered[mask]
+            atom_type = atom_type[mask]
+
+            plot_data3d(
+                positions_centered, atom_type, dataset_info=geom_with_h,
+                spheres_3d=False)
+
+    else:
+        # Report the unsupported selection, not the imported module.
+        raise ValueError(task_dataset)