--- a +++ b/analysis/visualization.py @@ -0,0 +1,472 @@ +import torch +import numpy as np +import os +import glob +import random +import matplotlib +import imageio + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from analysis.molecule_builder import get_bond_order + + +############## +### Files #### +###########--> + + +def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0, + name='molecule', batch_mask=None): + try: + os.makedirs(path) + except OSError: + pass + + if batch_mask is None: + batch_mask = torch.zeros(len(one_hot)) + + for batch_i in torch.unique(batch_mask): + cur_batch_mask = (batch_mask == batch_i) + n_atoms = int(torch.sum(cur_batch_mask).item()) + f = open(path + name + '_' + "%03d.xyz" % (batch_i + id_from), "w") + f.write("%d\n\n" % n_atoms) + atoms = torch.argmax(one_hot[cur_batch_mask], dim=1) + batch_pos = positions[cur_batch_mask] + for atom_i in range(n_atoms): + atom = atoms[atom_i] + atom = atom_decoder[atom] + f.write("%s %.9f %.9f %.9f\n" % (atom, batch_pos[atom_i, 0], batch_pos[atom_i, 1], batch_pos[atom_i, 2])) + f.close() + + +def load_molecule_xyz(file, dataset_info): + with open(file, encoding='utf8') as f: + n_atoms = int(f.readline()) + one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder'])) + positions = torch.zeros(n_atoms, 3) + f.readline() + atoms = f.readlines() + for i in range(n_atoms): + atom = atoms[i].split(' ') + atom_type = atom[0] + one_hot[i, dataset_info['atom_encoder'][atom_type]] = 1 + position = torch.Tensor([float(e) for e in atom[1:]]) + positions[i, :] = position + return positions, one_hot + + +def load_xyz_files(path, shuffle=True): + files = glob.glob(path + "/*.xyz") + if shuffle: + random.shuffle(files) + return files + + +# <----######## +### Files #### +############## +def draw_sphere(ax, x, y, z, size, color, alpha): + u = np.linspace(0, 2 * np.pi, 100) + v = np.linspace(0, np.pi, 100) + + xs = size * np.outer(np.cos(u), np.sin(v)) + ys = size * np.outer(np.sin(u), np.sin(v)) * 0.8 # Correct for matplotlib. + zs = size * np.outer(np.ones(np.size(u)), np.cos(v)) + # for i in range(2): + # ax.plot_surface(x+random.randint(-5,5), y+random.randint(-5,5), z+random.randint(-5,5), rstride=4, cstride=4, color='b', linewidth=0, alpha=0.5) + + ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, + linewidth=0, + alpha=alpha) + # # calculate vectors for "vertical" circle + # a = np.array([-np.sin(elev / 180 * np.pi), 0, np.cos(elev / 180 * np.pi)]) + # b = np.array([0, 1, 0]) + # b = b * np.cos(rot) + np.cross(a, b) * np.sin(rot) + a * np.dot(a, b) * ( + # 1 - np.cos(rot)) + # ax.plot(np.sin(u), np.cos(u), 0, color='k', linestyle='dashed') + # horiz_front = np.linspace(0, np.pi, 100) + # ax.plot(np.sin(horiz_front), np.cos(horiz_front), 0, color='k') + # vert_front = np.linspace(np.pi / 2, 3 * np.pi / 2, 100) + # ax.plot(a[0] * np.sin(u) + b[0] * np.cos(u), b[1] * np.cos(u), + # a[2] * np.sin(u) + b[2] * np.cos(u), color='k', linestyle='dashed') + # ax.plot(a[0] * np.sin(vert_front) + b[0] * np.cos(vert_front), + # b[1] * np.cos(vert_front), + # a[2] * np.sin(vert_front) + b[2] * np.cos(vert_front), color='k') + # + # ax.view_init(elev=elev, azim=0) + + +def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, + dataset_info): + # draw_sphere(ax, 0, 0, 0, 1) + # draw_sphere(ax, 1, 1, 1, 1) + + x = positions[:, 0] + y = positions[:, 1] + z = positions[:, 2] + # Hydrogen, Carbon, Nitrogen, Oxygen, Flourine + + # ax.set_facecolor((1.0, 0.47, 0.42)) + colors_dic = np.array(dataset_info['colors_dic']) + radius_dic = np.array(dataset_info['radius_dic']) + area_dic = 1500 * radius_dic ** 2 + # areas_dic = sizes_dic * sizes_dic * 3.1416 + + areas = area_dic[atom_type] + radii = radius_dic[atom_type] + colors = colors_dic[atom_type] + + if spheres_3d: + for i, j, k, s, c in zip(x, y, z, radii, colors): + draw_sphere(ax, i.item(), j.item(), k.item(), 0.7 * s, c, alpha) + else: + ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, + c=colors) # , linewidths=2, edgecolors='#FFFFFF') + + for i in range(len(x)): + for j in range(i + 1, len(x)): + p1 = np.array([x[i], y[i], z[i]]) + p2 = np.array([x[j], y[j], z[j]]) + dist = np.sqrt(np.sum((p1 - p2) ** 2)) + atom1, atom2 = dataset_info['atom_decoder'][atom_type[i]], \ + dataset_info['atom_decoder'][atom_type[j]] + s = (atom_type[i], atom_type[j]) + + draw_edge_int = get_bond_order(dataset_info['atom_decoder'][s[0]], + dataset_info['atom_decoder'][s[1]], + dist) + line_width = 2 + + draw_edge = draw_edge_int > 0 + if draw_edge: + if draw_edge_int == 4: + linewidth_factor = 1.5 + else: + # linewidth_factor = draw_edge_int # Prop to number of + # edges. + linewidth_factor = 1 + ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], + linewidth=line_width * linewidth_factor, + c=hex_bg_color, alpha=alpha) + + +def plot_data3d(positions, atom_type, dataset_info, camera_elev=0, + camera_azim=0, save_path=None, spheres_3d=False, + bg='black', alpha=1.): + black = (0, 0, 0) + white = (1, 1, 1) + hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666' + + from mpl_toolkits.mplot3d import Axes3D + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set_aspect('auto') + ax.view_init(elev=camera_elev, azim=camera_azim) + if bg == 'black': + ax.set_facecolor(black) + else: + ax.set_facecolor(white) + # ax.xaxis.pane.set_edgecolor('#D0D0D0') + ax.xaxis.pane.set_alpha(0) + ax.yaxis.pane.set_alpha(0) + ax.zaxis.pane.set_alpha(0) + ax._axis3don = False + + if bg == 'black': + ax.w_xaxis.line.set_color("black") + else: + ax.w_xaxis.line.set_color("white") + + plot_molecule(ax, positions, atom_type, alpha, spheres_3d, + hex_bg_color, dataset_info) + + # if 'qm9' in dataset_info['name']: + max_value = positions.abs().max().item() + + # axis_lim = 3.2 + axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) + ax.set_xlim(-axis_lim, axis_lim) + ax.set_ylim(-axis_lim, axis_lim) + ax.set_zlim(-axis_lim, axis_lim) + # elif dataset_info['name'] == 'geom': + # max_value = positions.abs().max().item() + # + # # axis_lim = 3.2 + # axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) + # ax.set_xlim(-axis_lim, axis_lim) + # ax.set_ylim(-axis_lim, axis_lim) + # ax.set_zlim(-axis_lim, axis_lim) + # elif dataset_info['name'] == 'pdbbind': + # max_value = positions.abs().max().item() + # + # # axis_lim = 3.2 + # axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) + # ax.set_xlim(-axis_lim, axis_lim) + # ax.set_ylim(-axis_lim, axis_lim) + # ax.set_zlim(-axis_lim, axis_lim) + # else: + # raise ValueError(dataset_info['name']) + + dpi = 120 if spheres_3d else 50 + + if save_path is not None: + plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi) + + if spheres_3d: + img = imageio.imread(save_path) + img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8') + imageio.imsave(save_path, img_brighter) + else: + plt.show() + plt.close() + + +def plot_data3d_uncertainty( + all_positions, all_atom_types, dataset_info, camera_elev=0, + camera_azim=0, + save_path=None, spheres_3d=False, bg='black', alpha=1.): + black = (0, 0, 0) + white = (1, 1, 1) + hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666' + + from mpl_toolkits.mplot3d import Axes3D + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set_aspect('auto') + ax.view_init(elev=camera_elev, azim=camera_azim) + if bg == 'black': + ax.set_facecolor(black) + else: + ax.set_facecolor(white) + # ax.xaxis.pane.set_edgecolor('#D0D0D0') + ax.xaxis.pane.set_alpha(0) + ax.yaxis.pane.set_alpha(0) + ax.zaxis.pane.set_alpha(0) + ax._axis3don = False + + if bg == 'black': + ax.w_xaxis.line.set_color("black") + else: + ax.w_xaxis.line.set_color("white") + + for i in range(len(all_positions)): + positions = all_positions[i] + atom_type = all_atom_types[i] + plot_molecule(ax, positions, atom_type, alpha, spheres_3d, + hex_bg_color, dataset_info) + + if 'qm9' in dataset_info['name']: + max_value = all_positions[0].abs().max().item() + + # axis_lim = 3.2 + axis_lim = min(40, max(max_value + 0.3, 3.2)) + ax.set_xlim(-axis_lim, axis_lim) + ax.set_ylim(-axis_lim, axis_lim) + ax.set_zlim(-axis_lim, axis_lim) + elif dataset_info['name'] == 'geom': + max_value = all_positions[0].abs().max().item() + + # axis_lim = 3.2 + axis_lim = min(40, max(max_value / 2 + 0.3, 3.2)) + ax.set_xlim(-axis_lim, axis_lim) + ax.set_ylim(-axis_lim, axis_lim) + ax.set_zlim(-axis_lim, axis_lim) + elif dataset_info['name'] == 'pdbbind': + max_value = all_positions[0].abs().max().item() + + # axis_lim = 3.2 + axis_lim = min(40, max(max_value / 2 + 0.3, 3.2)) + ax.set_xlim(-axis_lim, axis_lim) + ax.set_ylim(-axis_lim, axis_lim) + ax.set_zlim(-axis_lim, axis_lim) + else: + raise ValueError(dataset_info['name']) + + dpi = 120 if spheres_3d else 50 + + if save_path is not None: + plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi) + + if spheres_3d: + img = imageio.imread(save_path) + img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8') + imageio.imsave(save_path, img_brighter) + else: + plt.show() + plt.close() + + +def plot_grid(): + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import ImageGrid + + im1 = np.arange(100).reshape((10, 10)) + im2 = im1.T + im3 = np.flipud(im1) + im4 = np.fliplr(im2) + + fig = plt.figure(figsize=(10., 10.)) + grid = ImageGrid(fig, 111, # similar to subplot(111) + nrows_ncols=(6, 6), # creates 2x2 grid of axes + axes_pad=0.1, # pad between axes in inch. + ) + + for ax, im in zip(grid, [im1, im2, im3, im4]): + # Iterating over the grid returns the Axes. + + ax.imshow(im) + + plt.show() + + +def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False): + files = load_xyz_files(path)[0:max_num] + for file in files: + positions, one_hot = load_molecule_xyz(file, dataset_info) + atom_type = torch.argmax(one_hot, dim=1).numpy() + dists = torch.cdist(positions.unsqueeze(0), + positions.unsqueeze(0)).squeeze(0) + dists = dists[dists > 0] + # print("Average distance between atoms", dists.mean().item()) + plot_data3d(positions, atom_type, dataset_info=dataset_info, + save_path=file[:-4] + '.png', + spheres_3d=spheres_3d) + + if wandb is not None: + path = file[:-4] + '.png' + # Log image(s) + im = plt.imread(path) + wandb.log({'molecule': [wandb.Image(im, caption=path)]}) + + +def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False, + mode="chain"): + files = load_xyz_files(path) + files = sorted(files) + save_paths = [] + + for i in range(len(files)): + file = files[i] + + positions, one_hot = load_molecule_xyz(file, dataset_info=dataset_info) + + atom_type = torch.argmax(one_hot, dim=1).numpy() + fn = file[:-4] + '.png' + plot_data3d(positions, atom_type, dataset_info=dataset_info, + save_path=fn, spheres_3d=spheres_3d, alpha=1.0) + save_paths.append(fn) + + imgs = [imageio.imread(fn) for fn in save_paths] + dirname = os.path.dirname(save_paths[0]) + gif_path = dirname + '/output.gif' + print(f'Creating gif with {len(imgs)} images') + # Add the last frame 10 times so that the final result remains temporally. + # imgs.extend([imgs[-1]] * 10) + imageio.mimsave(gif_path, imgs, subrectangles=True) + + if wandb is not None: + wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]}) + + +def visualize_chain_uncertainty( + path, dataset_info, wandb=None, spheres_3d=False, mode="chain"): + files = load_xyz_files(path) + files = sorted(files) + save_paths = [] + + for i in range(len(files)): + if i + 2 == len(files): + break + + file = files[i] + file2 = files[i + 1] + file3 = files[i + 2] + + positions, one_hot, _ = load_molecule_xyz(file, + dataset_info=dataset_info) + positions2, one_hot2, _ = load_molecule_xyz( + file2, dataset_info=dataset_info) + positions3, one_hot3, _ = load_molecule_xyz( + file3, dataset_info=dataset_info) + + all_positions = torch.stack([positions, positions2, positions3], dim=0) + one_hot = torch.stack([one_hot, one_hot2, one_hot3], dim=0) + + all_atom_type = torch.argmax(one_hot, dim=2).numpy() + fn = file[:-4] + '.png' + plot_data3d_uncertainty( + all_positions, all_atom_type, dataset_info=dataset_info, + save_path=fn, spheres_3d=spheres_3d, alpha=0.5) + save_paths.append(fn) + + imgs = [imageio.imread(fn) for fn in save_paths] + dirname = os.path.dirname(save_paths[0]) + gif_path = dirname + '/output.gif' + print(f'Creating gif with {len(imgs)} images') + # Add the last frame 10 times so that the final result remains temporally. + # imgs.extend([imgs[-1]] * 10) + imageio.mimsave(gif_path, imgs, subrectangles=True) + + if wandb is not None: + wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]}) + + +if __name__ == '__main__': + # plot_grid() + import qm9.dataset as dataset + from configs.datasets_config import qm9_with_h, geom_with_h + + matplotlib.use('macosx') + + task = "visualize_molecules" + task_dataset = 'geom' + + if task_dataset == 'qm9': + dataset_info = qm9_with_h + + + class Args: + batch_size = 1 + num_workers = 0 + filter_n_atoms = None + datadir = 'qm9/temp' + dataset = 'qm9' + remove_h = False + + + cfg = Args() + + dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg) + + for i, data in enumerate(dataloaders['train']): + positions = data['positions'].view(-1, 3) + positions_centered = positions - positions.mean(dim=0, keepdim=True) + one_hot = data['one_hot'].view(-1, 5).type(torch.float32) + atom_type = torch.argmax(one_hot, dim=1).numpy() + + plot_data3d( + positions_centered, atom_type, dataset_info=dataset_info, + spheres_3d=True) + + elif task_dataset == 'geom': + files = load_xyz_files('outputs/data') + matplotlib.use('macosx') + for file in files: + x, one_hot, _ = load_molecule_xyz(file, dataset_info=geom_with_h) + + positions = x.view(-1, 3) + positions_centered = positions - positions.mean(dim=0, keepdim=True) + one_hot = one_hot.view(-1, 16).type(torch.float32) + atom_type = torch.argmax(one_hot, dim=1).numpy() + + mask = (x == 0).sum(1) != 3 + positions_centered = positions_centered[mask] + atom_type = atom_type[mask] + + plot_data3d( + positions_centered, atom_type, dataset_info=geom_with_h, + spheres_3d=False) + + else: + raise ValueError(dataset)