src/data/utils.py
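
"""Data utilities: label one-hot encoding, SMILES vocabulary (encoder/decoder)
construction, and conversion of batched molecular graphs into dense tensors."""
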
import os
import pickle

import pandas as pd
from tqdm import tqdm

import torch
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric.utils as geoutils

from rdkit import Chem, RDLogger


def label2onehot(labels, dim, device=None):
    """Convert label indices to one-hot vectors."""
    out = torch.zeros(list(labels.size()) + [dim])
    if device:
        out = out.to(device)

    # scatter_ expects an integer index tensor, so cast the labels before indexing
    out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1).long(), 1.)

    return out.float()

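# Illustrative sketch (not part of the module): label2onehot appends a one-hot
# dimension of width `dim` to the label tensor, e.g.
#
#   labels = torch.tensor([[0, 2, 1]])      # shape (1, 3)
#   onehot = label2onehot(labels, dim=3)    # shape (1, 3, 3)
#   # onehot[0, 1] -> tensor([0., 0., 1.])
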
def get_encoders_decoders(raw_file1, raw_file2, max_atom):
    """
    Given two raw SMILES files, either load the atom and bond encoders/decoders
    if they exist (naming them based on the file names) or create and save them.

    Parameters:
        raw_file1 (str): Path to the first SMILES file.
        raw_file2 (str): Path to the second SMILES file.
        max_atom (int): Maximum allowed number of atoms in a molecule.

    Returns:
        atom_encoder (dict): Mapping from atomic numbers to indices.
        atom_decoder (dict): Mapping from indices to atomic numbers.
        bond_encoder (dict): Mapping from bond types to indices.
        bond_decoder (dict): Mapping from indices to bond types.
    """
    # Determine unique suffix based on the two file names (alphabetically sorted for consistency)
    name1 = os.path.splitext(os.path.basename(raw_file1))[0]
    name2 = os.path.splitext(os.path.basename(raw_file2))[0]
    sorted_names = sorted([name1, name2])
    suffix = f"{sorted_names[0]}_{sorted_names[1]}"

    # Define encoder/decoder directories and file paths
    enc_dir = os.path.join("data", "encoders")
    dec_dir = os.path.join("data", "decoders")
    atom_encoder_path = os.path.join(enc_dir, f"atom_{suffix}.pkl")
    atom_decoder_path = os.path.join(dec_dir, f"atom_{suffix}.pkl")
    bond_encoder_path = os.path.join(enc_dir, f"bond_{suffix}.pkl")
    bond_decoder_path = os.path.join(dec_dir, f"bond_{suffix}.pkl")

    # If all files exist, load and return them
    if (os.path.exists(atom_encoder_path) and os.path.exists(atom_decoder_path) and
            os.path.exists(bond_encoder_path) and os.path.exists(bond_decoder_path)):
        with open(atom_encoder_path, "rb") as f:
            atom_encoder = pickle.load(f)
        with open(atom_decoder_path, "rb") as f:
            atom_decoder = pickle.load(f)
        with open(bond_encoder_path, "rb") as f:
            bond_encoder = pickle.load(f)
        with open(bond_decoder_path, "rb") as f:
            bond_decoder = pickle.load(f)
        print("Loaded existing encoders/decoders!")
        return atom_encoder, atom_decoder, bond_encoder, bond_decoder

    # Otherwise, create the encoders/decoders
    print("Creating new encoders/decoders...")
    # Read SMILES from both files (assuming one SMILES per row, no header)
    smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
    smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()
    smiles_combined = smiles1 + smiles2

    atom_labels = set()
    bond_labels = set()
    max_length = 0
    filtered_smiles = []

    # Process each SMILES: keep only valid molecules with <= max_atom atoms
    for smiles in tqdm(smiles_combined, desc="Processing SMILES"):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        molecule_size = mol.GetNumAtoms()
        if molecule_size > max_atom:
            continue
        filtered_smiles.append(smiles)
        # Collect atomic numbers
        atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
        max_length = max(max_length, molecule_size)
        # Collect bond types
        bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])

    # Add a PAD symbol (here using 0 for atoms)
    atom_labels.add(0)
    atom_labels = sorted(atom_labels)

    # For bonds, prepend the PAD bond type (using RDKit's BondType.ZERO)
    bond_labels = sorted(bond_labels)
    bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels

    # Create encoder and decoder dictionaries
    atom_encoder = {l: i for i, l in enumerate(atom_labels)}
    atom_decoder = {i: l for i, l in enumerate(atom_labels)}
    bond_encoder = {l: i for i, l in enumerate(bond_labels)}
    bond_decoder = {i: l for i, l in enumerate(bond_labels)}

    # Ensure directories exist
    os.makedirs(enc_dir, exist_ok=True)
    os.makedirs(dec_dir, exist_ok=True)

    # Save the encoders/decoders to disk
    with open(atom_encoder_path, "wb") as f:
        pickle.dump(atom_encoder, f)
    with open(atom_decoder_path, "wb") as f:
        pickle.dump(atom_decoder, f)
    with open(bond_encoder_path, "wb") as f:
        pickle.dump(bond_encoder, f)
    with open(bond_decoder_path, "wb") as f:
        pickle.dump(bond_decoder, f)

    print("Encoders/decoders created and saved.")
    return atom_encoder, atom_decoder, bond_encoder, bond_decoder

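# Usage sketch (illustrative; the file paths and max_atom value below are
# hypothetical, not fixed by this module):
#
#   atom_enc, atom_dec, bond_enc, bond_dec = get_encoders_decoders(
#       "data/dataset_a.smi", "data/dataset_b.smi", max_atom=45)
#   # atom_enc maps atomic numbers (plus PAD 0) to contiguous indices;
#   # bond_enc maps RDKit bond types (with BondType.ZERO as PAD) to indices.
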
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """Convert a batched PyG graph into dense node-feature, adjacency, and
    flattened per-graph tensors. Assumes every graph in the batch has the same
    number of nodes, so data.batch.shape[0] / batch_size gives the node count
    per molecule."""
    data = data.to(device)
    max_num_nodes = int(data.batch.shape[0] / batch_size)
    a = geoutils.to_dense_adj(
        edge_index=data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=max_num_nodes,
    )
    x_tensor = data.x.view(batch_size, max_num_nodes, -1)
    a_tensor = label2onehot(a, b_dim, device)

    # Flatten node features and one-hot adjacency, then concatenate them into a
    # single vector per graph
    a_tensor_vec = a_tensor.reshape(batch_size, -1)
    x_tensor_vec = x_tensor.reshape(batch_size, -1)
    real_graphs = torch.cat((x_tensor_vec, a_tensor_vec), dim=-1)

    return real_graphs, a_tensor, x_tensor
return real_graphs, a_tensor, x_tensor |