|
a |
|
b/analysis/visualization.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
import os |
|
|
4 |
import glob |
|
|
5 |
import random |
|
|
6 |
import matplotlib |
|
|
7 |
import imageio |
|
|
8 |
|
|
|
9 |
matplotlib.use('Agg') |
|
|
10 |
import matplotlib.pyplot as plt |
|
|
11 |
from analysis.molecule_builder import get_bond_order |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
############## |
|
|
15 |
### Files #### |
|
|
16 |
###########--> |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0, |
|
|
20 |
name='molecule', batch_mask=None): |
|
|
21 |
try: |
|
|
22 |
os.makedirs(path) |
|
|
23 |
except OSError: |
|
|
24 |
pass |
|
|
25 |
|
|
|
26 |
if batch_mask is None: |
|
|
27 |
batch_mask = torch.zeros(len(one_hot)) |
|
|
28 |
|
|
|
29 |
for batch_i in torch.unique(batch_mask): |
|
|
30 |
cur_batch_mask = (batch_mask == batch_i) |
|
|
31 |
n_atoms = int(torch.sum(cur_batch_mask).item()) |
|
|
32 |
f = open(path + name + '_' + "%03d.xyz" % (batch_i + id_from), "w") |
|
|
33 |
f.write("%d\n\n" % n_atoms) |
|
|
34 |
atoms = torch.argmax(one_hot[cur_batch_mask], dim=1) |
|
|
35 |
batch_pos = positions[cur_batch_mask] |
|
|
36 |
for atom_i in range(n_atoms): |
|
|
37 |
atom = atoms[atom_i] |
|
|
38 |
atom = atom_decoder[atom] |
|
|
39 |
f.write("%s %.9f %.9f %.9f\n" % (atom, batch_pos[atom_i, 0], batch_pos[atom_i, 1], batch_pos[atom_i, 2])) |
|
|
40 |
f.close() |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
def load_molecule_xyz(file, dataset_info): |
|
|
44 |
with open(file, encoding='utf8') as f: |
|
|
45 |
n_atoms = int(f.readline()) |
|
|
46 |
one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder'])) |
|
|
47 |
positions = torch.zeros(n_atoms, 3) |
|
|
48 |
f.readline() |
|
|
49 |
atoms = f.readlines() |
|
|
50 |
for i in range(n_atoms): |
|
|
51 |
atom = atoms[i].split(' ') |
|
|
52 |
atom_type = atom[0] |
|
|
53 |
one_hot[i, dataset_info['atom_encoder'][atom_type]] = 1 |
|
|
54 |
position = torch.Tensor([float(e) for e in atom[1:]]) |
|
|
55 |
positions[i, :] = position |
|
|
56 |
return positions, one_hot |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
def load_xyz_files(path, shuffle=True): |
|
|
60 |
files = glob.glob(path + "/*.xyz") |
|
|
61 |
if shuffle: |
|
|
62 |
random.shuffle(files) |
|
|
63 |
return files |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
# <----######## |
|
|
67 |
### Files #### |
|
|
68 |
############## |
|
|
69 |
def draw_sphere(ax, x, y, z, size, color, alpha): |
|
|
70 |
u = np.linspace(0, 2 * np.pi, 100) |
|
|
71 |
v = np.linspace(0, np.pi, 100) |
|
|
72 |
|
|
|
73 |
xs = size * np.outer(np.cos(u), np.sin(v)) |
|
|
74 |
ys = size * np.outer(np.sin(u), np.sin(v)) * 0.8 # Correct for matplotlib. |
|
|
75 |
zs = size * np.outer(np.ones(np.size(u)), np.cos(v)) |
|
|
76 |
# for i in range(2): |
|
|
77 |
# ax.plot_surface(x+random.randint(-5,5), y+random.randint(-5,5), z+random.randint(-5,5), rstride=4, cstride=4, color='b', linewidth=0, alpha=0.5) |
|
|
78 |
|
|
|
79 |
ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, |
|
|
80 |
linewidth=0, |
|
|
81 |
alpha=alpha) |
|
|
82 |
# # calculate vectors for "vertical" circle |
|
|
83 |
# a = np.array([-np.sin(elev / 180 * np.pi), 0, np.cos(elev / 180 * np.pi)]) |
|
|
84 |
# b = np.array([0, 1, 0]) |
|
|
85 |
# b = b * np.cos(rot) + np.cross(a, b) * np.sin(rot) + a * np.dot(a, b) * ( |
|
|
86 |
# 1 - np.cos(rot)) |
|
|
87 |
# ax.plot(np.sin(u), np.cos(u), 0, color='k', linestyle='dashed') |
|
|
88 |
# horiz_front = np.linspace(0, np.pi, 100) |
|
|
89 |
# ax.plot(np.sin(horiz_front), np.cos(horiz_front), 0, color='k') |
|
|
90 |
# vert_front = np.linspace(np.pi / 2, 3 * np.pi / 2, 100) |
|
|
91 |
# ax.plot(a[0] * np.sin(u) + b[0] * np.cos(u), b[1] * np.cos(u), |
|
|
92 |
# a[2] * np.sin(u) + b[2] * np.cos(u), color='k', linestyle='dashed') |
|
|
93 |
# ax.plot(a[0] * np.sin(vert_front) + b[0] * np.cos(vert_front), |
|
|
94 |
# b[1] * np.cos(vert_front), |
|
|
95 |
# a[2] * np.sin(vert_front) + b[2] * np.cos(vert_front), color='k') |
|
|
96 |
# |
|
|
97 |
# ax.view_init(elev=elev, azim=0) |
|
|
98 |
|
|
|
99 |
|
|
|
100 |
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, |
|
|
101 |
dataset_info): |
|
|
102 |
# draw_sphere(ax, 0, 0, 0, 1) |
|
|
103 |
# draw_sphere(ax, 1, 1, 1, 1) |
|
|
104 |
|
|
|
105 |
x = positions[:, 0] |
|
|
106 |
y = positions[:, 1] |
|
|
107 |
z = positions[:, 2] |
|
|
108 |
# Hydrogen, Carbon, Nitrogen, Oxygen, Flourine |
|
|
109 |
|
|
|
110 |
# ax.set_facecolor((1.0, 0.47, 0.42)) |
|
|
111 |
colors_dic = np.array(dataset_info['colors_dic']) |
|
|
112 |
radius_dic = np.array(dataset_info['radius_dic']) |
|
|
113 |
area_dic = 1500 * radius_dic ** 2 |
|
|
114 |
# areas_dic = sizes_dic * sizes_dic * 3.1416 |
|
|
115 |
|
|
|
116 |
areas = area_dic[atom_type] |
|
|
117 |
radii = radius_dic[atom_type] |
|
|
118 |
colors = colors_dic[atom_type] |
|
|
119 |
|
|
|
120 |
if spheres_3d: |
|
|
121 |
for i, j, k, s, c in zip(x, y, z, radii, colors): |
|
|
122 |
draw_sphere(ax, i.item(), j.item(), k.item(), 0.7 * s, c, alpha) |
|
|
123 |
else: |
|
|
124 |
ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, |
|
|
125 |
c=colors) # , linewidths=2, edgecolors='#FFFFFF') |
|
|
126 |
|
|
|
127 |
for i in range(len(x)): |
|
|
128 |
for j in range(i + 1, len(x)): |
|
|
129 |
p1 = np.array([x[i], y[i], z[i]]) |
|
|
130 |
p2 = np.array([x[j], y[j], z[j]]) |
|
|
131 |
dist = np.sqrt(np.sum((p1 - p2) ** 2)) |
|
|
132 |
atom1, atom2 = dataset_info['atom_decoder'][atom_type[i]], \ |
|
|
133 |
dataset_info['atom_decoder'][atom_type[j]] |
|
|
134 |
s = (atom_type[i], atom_type[j]) |
|
|
135 |
|
|
|
136 |
draw_edge_int = get_bond_order(dataset_info['atom_decoder'][s[0]], |
|
|
137 |
dataset_info['atom_decoder'][s[1]], |
|
|
138 |
dist) |
|
|
139 |
line_width = 2 |
|
|
140 |
|
|
|
141 |
draw_edge = draw_edge_int > 0 |
|
|
142 |
if draw_edge: |
|
|
143 |
if draw_edge_int == 4: |
|
|
144 |
linewidth_factor = 1.5 |
|
|
145 |
else: |
|
|
146 |
# linewidth_factor = draw_edge_int # Prop to number of |
|
|
147 |
# edges. |
|
|
148 |
linewidth_factor = 1 |
|
|
149 |
ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]], |
|
|
150 |
linewidth=line_width * linewidth_factor, |
|
|
151 |
c=hex_bg_color, alpha=alpha) |
|
|
152 |
|
|
|
153 |
|
|
|
154 |
def plot_data3d(positions, atom_type, dataset_info, camera_elev=0, |
|
|
155 |
camera_azim=0, save_path=None, spheres_3d=False, |
|
|
156 |
bg='black', alpha=1.): |
|
|
157 |
black = (0, 0, 0) |
|
|
158 |
white = (1, 1, 1) |
|
|
159 |
hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666' |
|
|
160 |
|
|
|
161 |
from mpl_toolkits.mplot3d import Axes3D |
|
|
162 |
fig = plt.figure() |
|
|
163 |
ax = fig.add_subplot(projection='3d') |
|
|
164 |
ax.set_aspect('auto') |
|
|
165 |
ax.view_init(elev=camera_elev, azim=camera_azim) |
|
|
166 |
if bg == 'black': |
|
|
167 |
ax.set_facecolor(black) |
|
|
168 |
else: |
|
|
169 |
ax.set_facecolor(white) |
|
|
170 |
# ax.xaxis.pane.set_edgecolor('#D0D0D0') |
|
|
171 |
ax.xaxis.pane.set_alpha(0) |
|
|
172 |
ax.yaxis.pane.set_alpha(0) |
|
|
173 |
ax.zaxis.pane.set_alpha(0) |
|
|
174 |
ax._axis3don = False |
|
|
175 |
|
|
|
176 |
if bg == 'black': |
|
|
177 |
ax.w_xaxis.line.set_color("black") |
|
|
178 |
else: |
|
|
179 |
ax.w_xaxis.line.set_color("white") |
|
|
180 |
|
|
|
181 |
plot_molecule(ax, positions, atom_type, alpha, spheres_3d, |
|
|
182 |
hex_bg_color, dataset_info) |
|
|
183 |
|
|
|
184 |
# if 'qm9' in dataset_info['name']: |
|
|
185 |
max_value = positions.abs().max().item() |
|
|
186 |
|
|
|
187 |
# axis_lim = 3.2 |
|
|
188 |
axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) |
|
|
189 |
ax.set_xlim(-axis_lim, axis_lim) |
|
|
190 |
ax.set_ylim(-axis_lim, axis_lim) |
|
|
191 |
ax.set_zlim(-axis_lim, axis_lim) |
|
|
192 |
# elif dataset_info['name'] == 'geom': |
|
|
193 |
# max_value = positions.abs().max().item() |
|
|
194 |
# |
|
|
195 |
# # axis_lim = 3.2 |
|
|
196 |
# axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) |
|
|
197 |
# ax.set_xlim(-axis_lim, axis_lim) |
|
|
198 |
# ax.set_ylim(-axis_lim, axis_lim) |
|
|
199 |
# ax.set_zlim(-axis_lim, axis_lim) |
|
|
200 |
# elif dataset_info['name'] == 'pdbbind': |
|
|
201 |
# max_value = positions.abs().max().item() |
|
|
202 |
# |
|
|
203 |
# # axis_lim = 3.2 |
|
|
204 |
# axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2)) |
|
|
205 |
# ax.set_xlim(-axis_lim, axis_lim) |
|
|
206 |
# ax.set_ylim(-axis_lim, axis_lim) |
|
|
207 |
# ax.set_zlim(-axis_lim, axis_lim) |
|
|
208 |
# else: |
|
|
209 |
# raise ValueError(dataset_info['name']) |
|
|
210 |
|
|
|
211 |
dpi = 120 if spheres_3d else 50 |
|
|
212 |
|
|
|
213 |
if save_path is not None: |
|
|
214 |
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi) |
|
|
215 |
|
|
|
216 |
if spheres_3d: |
|
|
217 |
img = imageio.imread(save_path) |
|
|
218 |
img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8') |
|
|
219 |
imageio.imsave(save_path, img_brighter) |
|
|
220 |
else: |
|
|
221 |
plt.show() |
|
|
222 |
plt.close() |
|
|
223 |
|
|
|
224 |
|
|
|
225 |
def plot_data3d_uncertainty( |
|
|
226 |
all_positions, all_atom_types, dataset_info, camera_elev=0, |
|
|
227 |
camera_azim=0, |
|
|
228 |
save_path=None, spheres_3d=False, bg='black', alpha=1.): |
|
|
229 |
black = (0, 0, 0) |
|
|
230 |
white = (1, 1, 1) |
|
|
231 |
hex_bg_color = '#FFFFFF' if bg == 'black' else '#666666' |
|
|
232 |
|
|
|
233 |
from mpl_toolkits.mplot3d import Axes3D |
|
|
234 |
fig = plt.figure() |
|
|
235 |
ax = fig.add_subplot(projection='3d') |
|
|
236 |
ax.set_aspect('auto') |
|
|
237 |
ax.view_init(elev=camera_elev, azim=camera_azim) |
|
|
238 |
if bg == 'black': |
|
|
239 |
ax.set_facecolor(black) |
|
|
240 |
else: |
|
|
241 |
ax.set_facecolor(white) |
|
|
242 |
# ax.xaxis.pane.set_edgecolor('#D0D0D0') |
|
|
243 |
ax.xaxis.pane.set_alpha(0) |
|
|
244 |
ax.yaxis.pane.set_alpha(0) |
|
|
245 |
ax.zaxis.pane.set_alpha(0) |
|
|
246 |
ax._axis3don = False |
|
|
247 |
|
|
|
248 |
if bg == 'black': |
|
|
249 |
ax.w_xaxis.line.set_color("black") |
|
|
250 |
else: |
|
|
251 |
ax.w_xaxis.line.set_color("white") |
|
|
252 |
|
|
|
253 |
for i in range(len(all_positions)): |
|
|
254 |
positions = all_positions[i] |
|
|
255 |
atom_type = all_atom_types[i] |
|
|
256 |
plot_molecule(ax, positions, atom_type, alpha, spheres_3d, |
|
|
257 |
hex_bg_color, dataset_info) |
|
|
258 |
|
|
|
259 |
if 'qm9' in dataset_info['name']: |
|
|
260 |
max_value = all_positions[0].abs().max().item() |
|
|
261 |
|
|
|
262 |
# axis_lim = 3.2 |
|
|
263 |
axis_lim = min(40, max(max_value + 0.3, 3.2)) |
|
|
264 |
ax.set_xlim(-axis_lim, axis_lim) |
|
|
265 |
ax.set_ylim(-axis_lim, axis_lim) |
|
|
266 |
ax.set_zlim(-axis_lim, axis_lim) |
|
|
267 |
elif dataset_info['name'] == 'geom': |
|
|
268 |
max_value = all_positions[0].abs().max().item() |
|
|
269 |
|
|
|
270 |
# axis_lim = 3.2 |
|
|
271 |
axis_lim = min(40, max(max_value / 2 + 0.3, 3.2)) |
|
|
272 |
ax.set_xlim(-axis_lim, axis_lim) |
|
|
273 |
ax.set_ylim(-axis_lim, axis_lim) |
|
|
274 |
ax.set_zlim(-axis_lim, axis_lim) |
|
|
275 |
elif dataset_info['name'] == 'pdbbind': |
|
|
276 |
max_value = all_positions[0].abs().max().item() |
|
|
277 |
|
|
|
278 |
# axis_lim = 3.2 |
|
|
279 |
axis_lim = min(40, max(max_value / 2 + 0.3, 3.2)) |
|
|
280 |
ax.set_xlim(-axis_lim, axis_lim) |
|
|
281 |
ax.set_ylim(-axis_lim, axis_lim) |
|
|
282 |
ax.set_zlim(-axis_lim, axis_lim) |
|
|
283 |
else: |
|
|
284 |
raise ValueError(dataset_info['name']) |
|
|
285 |
|
|
|
286 |
dpi = 120 if spheres_3d else 50 |
|
|
287 |
|
|
|
288 |
if save_path is not None: |
|
|
289 |
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi) |
|
|
290 |
|
|
|
291 |
if spheres_3d: |
|
|
292 |
img = imageio.imread(save_path) |
|
|
293 |
img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8') |
|
|
294 |
imageio.imsave(save_path, img_brighter) |
|
|
295 |
else: |
|
|
296 |
plt.show() |
|
|
297 |
plt.close() |
|
|
298 |
|
|
|
299 |
|
|
|
300 |
def plot_grid(): |
|
|
301 |
import matplotlib.pyplot as plt |
|
|
302 |
from mpl_toolkits.axes_grid1 import ImageGrid |
|
|
303 |
|
|
|
304 |
im1 = np.arange(100).reshape((10, 10)) |
|
|
305 |
im2 = im1.T |
|
|
306 |
im3 = np.flipud(im1) |
|
|
307 |
im4 = np.fliplr(im2) |
|
|
308 |
|
|
|
309 |
fig = plt.figure(figsize=(10., 10.)) |
|
|
310 |
grid = ImageGrid(fig, 111, # similar to subplot(111) |
|
|
311 |
nrows_ncols=(6, 6), # creates 2x2 grid of axes |
|
|
312 |
axes_pad=0.1, # pad between axes in inch. |
|
|
313 |
) |
|
|
314 |
|
|
|
315 |
for ax, im in zip(grid, [im1, im2, im3, im4]): |
|
|
316 |
# Iterating over the grid returns the Axes. |
|
|
317 |
|
|
|
318 |
ax.imshow(im) |
|
|
319 |
|
|
|
320 |
plt.show() |
|
|
321 |
|
|
|
322 |
|
|
|
323 |
def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False): |
|
|
324 |
files = load_xyz_files(path)[0:max_num] |
|
|
325 |
for file in files: |
|
|
326 |
positions, one_hot = load_molecule_xyz(file, dataset_info) |
|
|
327 |
atom_type = torch.argmax(one_hot, dim=1).numpy() |
|
|
328 |
dists = torch.cdist(positions.unsqueeze(0), |
|
|
329 |
positions.unsqueeze(0)).squeeze(0) |
|
|
330 |
dists = dists[dists > 0] |
|
|
331 |
# print("Average distance between atoms", dists.mean().item()) |
|
|
332 |
plot_data3d(positions, atom_type, dataset_info=dataset_info, |
|
|
333 |
save_path=file[:-4] + '.png', |
|
|
334 |
spheres_3d=spheres_3d) |
|
|
335 |
|
|
|
336 |
if wandb is not None: |
|
|
337 |
path = file[:-4] + '.png' |
|
|
338 |
# Log image(s) |
|
|
339 |
im = plt.imread(path) |
|
|
340 |
wandb.log({'molecule': [wandb.Image(im, caption=path)]}) |
|
|
341 |
|
|
|
342 |
|
|
|
343 |
def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False, |
|
|
344 |
mode="chain"): |
|
|
345 |
files = load_xyz_files(path) |
|
|
346 |
files = sorted(files) |
|
|
347 |
save_paths = [] |
|
|
348 |
|
|
|
349 |
for i in range(len(files)): |
|
|
350 |
file = files[i] |
|
|
351 |
|
|
|
352 |
positions, one_hot = load_molecule_xyz(file, dataset_info=dataset_info) |
|
|
353 |
|
|
|
354 |
atom_type = torch.argmax(one_hot, dim=1).numpy() |
|
|
355 |
fn = file[:-4] + '.png' |
|
|
356 |
plot_data3d(positions, atom_type, dataset_info=dataset_info, |
|
|
357 |
save_path=fn, spheres_3d=spheres_3d, alpha=1.0) |
|
|
358 |
save_paths.append(fn) |
|
|
359 |
|
|
|
360 |
imgs = [imageio.imread(fn) for fn in save_paths] |
|
|
361 |
dirname = os.path.dirname(save_paths[0]) |
|
|
362 |
gif_path = dirname + '/output.gif' |
|
|
363 |
print(f'Creating gif with {len(imgs)} images') |
|
|
364 |
# Add the last frame 10 times so that the final result remains temporally. |
|
|
365 |
# imgs.extend([imgs[-1]] * 10) |
|
|
366 |
imageio.mimsave(gif_path, imgs, subrectangles=True) |
|
|
367 |
|
|
|
368 |
if wandb is not None: |
|
|
369 |
wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]}) |
|
|
370 |
|
|
|
371 |
|
|
|
372 |
def visualize_chain_uncertainty( |
|
|
373 |
path, dataset_info, wandb=None, spheres_3d=False, mode="chain"): |
|
|
374 |
files = load_xyz_files(path) |
|
|
375 |
files = sorted(files) |
|
|
376 |
save_paths = [] |
|
|
377 |
|
|
|
378 |
for i in range(len(files)): |
|
|
379 |
if i + 2 == len(files): |
|
|
380 |
break |
|
|
381 |
|
|
|
382 |
file = files[i] |
|
|
383 |
file2 = files[i + 1] |
|
|
384 |
file3 = files[i + 2] |
|
|
385 |
|
|
|
386 |
positions, one_hot, _ = load_molecule_xyz(file, |
|
|
387 |
dataset_info=dataset_info) |
|
|
388 |
positions2, one_hot2, _ = load_molecule_xyz( |
|
|
389 |
file2, dataset_info=dataset_info) |
|
|
390 |
positions3, one_hot3, _ = load_molecule_xyz( |
|
|
391 |
file3, dataset_info=dataset_info) |
|
|
392 |
|
|
|
393 |
all_positions = torch.stack([positions, positions2, positions3], dim=0) |
|
|
394 |
one_hot = torch.stack([one_hot, one_hot2, one_hot3], dim=0) |
|
|
395 |
|
|
|
396 |
all_atom_type = torch.argmax(one_hot, dim=2).numpy() |
|
|
397 |
fn = file[:-4] + '.png' |
|
|
398 |
plot_data3d_uncertainty( |
|
|
399 |
all_positions, all_atom_type, dataset_info=dataset_info, |
|
|
400 |
save_path=fn, spheres_3d=spheres_3d, alpha=0.5) |
|
|
401 |
save_paths.append(fn) |
|
|
402 |
|
|
|
403 |
imgs = [imageio.imread(fn) for fn in save_paths] |
|
|
404 |
dirname = os.path.dirname(save_paths[0]) |
|
|
405 |
gif_path = dirname + '/output.gif' |
|
|
406 |
print(f'Creating gif with {len(imgs)} images') |
|
|
407 |
# Add the last frame 10 times so that the final result remains temporally. |
|
|
408 |
# imgs.extend([imgs[-1]] * 10) |
|
|
409 |
imageio.mimsave(gif_path, imgs, subrectangles=True) |
|
|
410 |
|
|
|
411 |
if wandb is not None: |
|
|
412 |
wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]}) |
|
|
413 |
|
|
|
414 |
|
|
|
415 |
if __name__ == '__main__': |
|
|
416 |
# plot_grid() |
|
|
417 |
import qm9.dataset as dataset |
|
|
418 |
from configs.datasets_config import qm9_with_h, geom_with_h |
|
|
419 |
|
|
|
420 |
matplotlib.use('macosx') |
|
|
421 |
|
|
|
422 |
task = "visualize_molecules" |
|
|
423 |
task_dataset = 'geom' |
|
|
424 |
|
|
|
425 |
if task_dataset == 'qm9': |
|
|
426 |
dataset_info = qm9_with_h |
|
|
427 |
|
|
|
428 |
|
|
|
429 |
class Args: |
|
|
430 |
batch_size = 1 |
|
|
431 |
num_workers = 0 |
|
|
432 |
filter_n_atoms = None |
|
|
433 |
datadir = 'qm9/temp' |
|
|
434 |
dataset = 'qm9' |
|
|
435 |
remove_h = False |
|
|
436 |
|
|
|
437 |
|
|
|
438 |
cfg = Args() |
|
|
439 |
|
|
|
440 |
dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg) |
|
|
441 |
|
|
|
442 |
for i, data in enumerate(dataloaders['train']): |
|
|
443 |
positions = data['positions'].view(-1, 3) |
|
|
444 |
positions_centered = positions - positions.mean(dim=0, keepdim=True) |
|
|
445 |
one_hot = data['one_hot'].view(-1, 5).type(torch.float32) |
|
|
446 |
atom_type = torch.argmax(one_hot, dim=1).numpy() |
|
|
447 |
|
|
|
448 |
plot_data3d( |
|
|
449 |
positions_centered, atom_type, dataset_info=dataset_info, |
|
|
450 |
spheres_3d=True) |
|
|
451 |
|
|
|
452 |
elif task_dataset == 'geom': |
|
|
453 |
files = load_xyz_files('outputs/data') |
|
|
454 |
matplotlib.use('macosx') |
|
|
455 |
for file in files: |
|
|
456 |
x, one_hot, _ = load_molecule_xyz(file, dataset_info=geom_with_h) |
|
|
457 |
|
|
|
458 |
positions = x.view(-1, 3) |
|
|
459 |
positions_centered = positions - positions.mean(dim=0, keepdim=True) |
|
|
460 |
one_hot = one_hot.view(-1, 16).type(torch.float32) |
|
|
461 |
atom_type = torch.argmax(one_hot, dim=1).numpy() |
|
|
462 |
|
|
|
463 |
mask = (x == 0).sum(1) != 3 |
|
|
464 |
positions_centered = positions_centered[mask] |
|
|
465 |
atom_type = atom_type[mask] |
|
|
466 |
|
|
|
467 |
plot_data3d( |
|
|
468 |
positions_centered, atom_type, dataset_info=geom_with_h, |
|
|
469 |
spheres_3d=False) |
|
|
470 |
|
|
|
471 |
else: |
|
|
472 |
raise ValueError(dataset) |