[d33920]: / src / data / utils.py

Download this file

143 lines (117 with data), 5.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import pickle
import pandas as pd
from tqdm import tqdm
import torch
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric.utils as geoutils
from rdkit import Chem, RDLogger
def label2onehot(labels, dim, device=None):
    """Convert label indices to one-hot vectors.

    Args:
        labels: Tensor of class indices of any shape. Every value must lie in
            ``[0, dim)``. Non-int64 tensors (e.g. float-valued dense adjacency
            matrices) are cast to ``int64``, because ``scatter_`` requires an
            int64 index.
        dim: Number of classes, i.e. the size of the new trailing dimension.
        device: Device to build the result on. When falsy, the result is built
            on ``labels.device`` so CPU and GPU inputs both work.

    Returns:
        float32 tensor of shape ``labels.shape + (dim,)`` with a 1.0 at each
        label's index along the last dimension and 0.0 elsewhere.
    """
    target_device = device if device else labels.device
    # Allocate directly on the target device (no CPU tensor + copy) and make the
    # index live on the same device with the dtype scatter_ demands.
    index = labels.to(device=target_device, dtype=torch.long).unsqueeze(-1)
    out = torch.zeros(*labels.shape, dim, device=target_device)
    out.scatter_(-1, index, 1.0)
    # .float() keeps the float32 result even if the global default dtype changed.
    return out.float()
def _encoder_decoder_paths(raw_file1, raw_file2):
    """Return the (atom enc, atom dec, bond enc, bond dec) pickle paths for a file pair.

    The suffix joins the two file stems in alphabetical order, so the cache
    location does not depend on the order of the arguments.
    """
    stem1 = os.path.splitext(os.path.basename(raw_file1))[0]
    stem2 = os.path.splitext(os.path.basename(raw_file2))[0]
    first, second = sorted([stem1, stem2])
    suffix = f"{first}_{second}"
    enc_dir = os.path.join("data", "encoders")
    dec_dir = os.path.join("data", "decoders")
    return (
        os.path.join(enc_dir, f"atom_{suffix}.pkl"),
        os.path.join(dec_dir, f"atom_{suffix}.pkl"),
        os.path.join(enc_dir, f"bond_{suffix}.pkl"),
        os.path.join(dec_dir, f"bond_{suffix}.pkl"),
    )


def _collect_labels(smiles_list, max_atom):
    """Return (atomic numbers, RDKit bond types) seen in parseable molecules with <= max_atom atoms."""
    atom_labels = set()
    bond_labels = set()
    for smiles in tqdm(smiles_list, desc="Processing SMILES"):
        # pandas reports missing cells as NaN (a float); RDKit raises on non-str input.
        if not isinstance(smiles, str):
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is None or mol.GetNumAtoms() > max_atom:
            continue
        atom_labels.update(atom.GetAtomicNum() for atom in mol.GetAtoms())
        bond_labels.update(bond.GetBondType() for bond in mol.GetBonds())
    return atom_labels, bond_labels


def get_encoders_decoders(raw_file1, raw_file2, max_atom):
    """
    Given two raw SMILES files, either load the atom and bond encoders/decoders
    if they exist (naming them based on the file names) or create and save them.

    Cached pickles live under ``data/encoders`` and ``data/decoders`` relative
    to the current working directory. The cache key is built from the file
    names only. It does NOT include ``max_atom``, so delete the cached files
    after changing ``max_atom``.

    Parameters:
        raw_file1 (str): Path to the first SMILES file (one SMILES per row, no header).
        raw_file2 (str): Path to the second SMILES file.
        max_atom (int): Molecules with more atoms than this are ignored when
            building the vocabularies.

    Returns:
        atom_encoder (dict): Mapping from atomic numbers to indices (0 = padding).
        atom_decoder (dict): Mapping from indices to atomic numbers.
        bond_encoder (dict): Mapping from RDKit bond types to indices
            (0 = ``BondType.ZERO`` padding).
        bond_decoder (dict): Mapping from indices to RDKit bond types.
    """
    paths = _encoder_decoder_paths(raw_file1, raw_file2)

    if all(os.path.exists(path) for path in paths):
        # NOTE: pickle.load can execute arbitrary code. These files are only
        # expected to be the caches written below; never point this at
        # untrusted data.
        loaded = []
        for path in paths:
            with open(path, "rb") as f:
                loaded.append(pickle.load(f))
        print("Loaded existing encoders/decoders!")
        return tuple(loaded)

    print("Creating new encoders/decoders...")
    smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
    smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()
    atom_labels, bond_labels = _collect_labels(smiles1 + smiles2, max_atom)

    # Index 0 is the padding class in both vocabularies. Atomic numbers are
    # non-negative, so 0 sorts first. For bonds, BondType.ZERO is prepended;
    # it is discarded from the observed set first so it cannot appear twice
    # (a duplicate would remap ZERO to a non-zero index).
    atom_labels.add(0)
    atom_vocab = sorted(atom_labels)
    pad_bond = Chem.rdchem.BondType.ZERO
    bond_labels.discard(pad_bond)
    bond_vocab = [pad_bond] + sorted(bond_labels)

    atom_encoder = {label: idx for idx, label in enumerate(atom_vocab)}
    atom_decoder = dict(enumerate(atom_vocab))
    bond_encoder = {label: idx for idx, label in enumerate(bond_vocab)}
    bond_decoder = dict(enumerate(bond_vocab))
    result = (atom_encoder, atom_decoder, bond_encoder, bond_decoder)

    for path in paths:
        os.makedirs(os.path.dirname(path), exist_ok=True)
    for path, obj in zip(paths, result):
        with open(path, "wb") as f:
            pickle.dump(obj, f)
    print("Encoders/decoders created and saved.")
    return result
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """Densify a batched graph into flat per-graph vectors.

    Args:
        data: Batched graph exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch`` and ``.to()``. It is required; the ``None`` default is
            kept only for signature compatibility.
        b_dim: Number of classes used to one-hot encode the dense adjacency values.
        m_dim: Unused; kept for signature compatibility.
        device: Device ``data`` is moved to, also passed to ``label2onehot``.
        batch_size: Number of graphs expected in ``data``.

    Returns:
        ``(real_graphs, a_tensor, x_tensor)``:
        - ``x_tensor``: ``data.x`` viewed as ``(batch_size, n, -1)`` with
          ``n = num_nodes // batch_size``.
        - ``a_tensor``: one-hot encoding (last dim ``b_dim``) of the dense
          adjacency from ``to_dense_adj``. This is presumably
          ``(batch_size, n, n, b_dim)`` when ``edge_attr`` holds one label per
          edge; TODO confirm.
        - ``real_graphs``: ``x_tensor`` and ``a_tensor`` each flattened per
          graph and concatenated, shape ``(batch_size, -1)``.

    Raises:
        ValueError: if ``data`` is None, if the node count is not a positive
            multiple of ``batch_size``, or if ``data`` reports a different
            number of graphs than ``batch_size``.
    """
    if data is None:
        raise ValueError("load_molecules requires a batched graph, got data=None")
    data = data.to(device)

    num_nodes = data.batch.shape[0]
    if batch_size <= 0 or num_nodes % batch_size:
        raise ValueError(
            f"cannot split {num_nodes} nodes evenly into batch_size={batch_size} graphs"
        )
    # A short final batch would otherwise be silently mis-split by view/reshape.
    num_graphs = getattr(data, "num_graphs", None)
    if num_graphs is not None and num_graphs != batch_size:
        raise ValueError(
            f"batch contains {num_graphs} graphs but batch_size={batch_size}"
        )
    # Assumes every graph is padded to the same node count (presumably max_atom);
    # TODO confirm. The dense views below rely on it.
    nodes_per_graph = num_nodes // batch_size

    a = geoutils.to_dense_adj(
        edge_index=data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=nodes_per_graph,
    )
    x_tensor = data.x.view(batch_size, nodes_per_graph, -1)
    a_tensor = label2onehot(a, b_dim, device)
    real_graphs = torch.concat(
        (x_tensor.reshape(batch_size, -1), a_tensor.reshape(batch_size, -1)),
        dim=-1,
    )
    return real_graphs, a_tensor, x_tensor