Diff of /src/data/utils.py [000000] .. [7d53f6]

Switch to unified view

a b/src/data/utils.py
1
import os
2
import pickle
3
4
import pandas as pd
5
from tqdm import tqdm
6
7
import torch
8
from torch_geometric.data import Data, InMemoryDataset
9
import torch_geometric.utils as geoutils
10
11
from rdkit import Chem, RDLogger
12
13
14
15
def label2onehot(labels, dim, device=None):
16
    """Convert label indices to one-hot vectors."""
17
    out = torch.zeros(list(labels.size())+[dim])
18
    if device:
19
        out = out.to(device)
20
21
    out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
22
23
    return out.float()
24
25
26
def get_encoders_decoders(raw_file1, raw_file2, max_atom):
27
    """
28
    Given two raw SMILES files, either load the atom and bond encoders/decoders
29
    if they exist (naming them based on the file names) or create and save them.
30
31
    Parameters:
32
        raw_file1 (str): Path to the first SMILES file.
33
        raw_file2 (str): Path to the second SMILES file.
34
        max_atom (int): Maximum allowed number of atoms in a molecule.
35
36
    Returns:
37
        atom_encoder (dict): Mapping from atomic numbers to indices.
38
        atom_decoder (dict): Mapping from indices to atomic numbers.
39
        bond_encoder (dict): Mapping from bond types to indices.
40
        bond_decoder (dict): Mapping from indices to bond types.
41
    """
42
    # Determine unique suffix based on the two file names (alphabetically sorted for consistency)
43
    name1 = os.path.splitext(os.path.basename(raw_file1))[0]
44
    name2 = os.path.splitext(os.path.basename(raw_file2))[0]
45
    sorted_names = sorted([name1, name2])
46
    suffix = f"{sorted_names[0]}_{sorted_names[1]}"
47
48
    # Define encoder/decoder directories and file paths
49
    enc_dir = os.path.join("data", "encoders")
50
    dec_dir = os.path.join("data", "decoders")
51
    atom_encoder_path = os.path.join(enc_dir, f"atom_{suffix}.pkl")
52
    atom_decoder_path = os.path.join(dec_dir, f"atom_{suffix}.pkl")
53
    bond_encoder_path = os.path.join(enc_dir, f"bond_{suffix}.pkl")
54
    bond_decoder_path = os.path.join(dec_dir, f"bond_{suffix}.pkl")
55
56
    # If all files exist, load and return them
57
    if (os.path.exists(atom_encoder_path) and os.path.exists(atom_decoder_path) and 
58
        os.path.exists(bond_encoder_path) and os.path.exists(bond_decoder_path)):
59
        with open(atom_encoder_path, "rb") as f:
60
            atom_encoder = pickle.load(f)
61
        with open(atom_decoder_path, "rb") as f:
62
            atom_decoder = pickle.load(f)
63
        with open(bond_encoder_path, "rb") as f:
64
            bond_encoder = pickle.load(f)
65
        with open(bond_decoder_path, "rb") as f:
66
            bond_decoder = pickle.load(f)
67
        print("Loaded existing encoders/decoders!")
68
        return atom_encoder, atom_decoder, bond_encoder, bond_decoder
69
70
    # Otherwise, create the encoders/decoders
71
    print("Creating new encoders/decoders...")
72
    # Read SMILES from both files (assuming one SMILES per row, no header)
73
    smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
74
    smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()
75
    smiles_combined = smiles1 + smiles2
76
77
    atom_labels = set()
78
    bond_labels = set()
79
    max_length = 0
80
    filtered_smiles = []
81
    
82
    # Process each SMILES: keep only valid molecules with <= max_atom atoms
83
    for smiles in tqdm(smiles_combined, desc="Processing SMILES"):
84
        mol = Chem.MolFromSmiles(smiles)
85
        if mol is None:
86
            continue
87
        molecule_size = mol.GetNumAtoms()
88
        if molecule_size > max_atom:
89
            continue
90
        filtered_smiles.append(smiles)
91
        # Collect atomic numbers
92
        atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
93
        max_length = max(max_length, molecule_size)
94
        # Collect bond types
95
        bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])
96
    
97
    # Add a PAD symbol (here using 0 for atoms)
98
    atom_labels.add(0)
99
    atom_labels = sorted(atom_labels)
100
    
101
    # For bonds, prepend the PAD bond type (using rdkit's BondType.ZERO)
102
    bond_labels = sorted(bond_labels)
103
    bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels
104
105
    # Create encoder and decoder dictionaries
106
    atom_encoder = {l: i for i, l in enumerate(atom_labels)}
107
    atom_decoder = {i: l for i, l in enumerate(atom_labels)}
108
    bond_encoder = {l: i for i, l in enumerate(bond_labels)}
109
    bond_decoder = {i: l for i, l in enumerate(bond_labels)}
110
111
    # Ensure directories exist
112
    os.makedirs(enc_dir, exist_ok=True)
113
    os.makedirs(dec_dir, exist_ok=True)
114
115
    # Save the encoders/decoders to disk
116
    with open(atom_encoder_path, "wb") as f:
117
        pickle.dump(atom_encoder, f)
118
    with open(atom_decoder_path, "wb") as f:
119
        pickle.dump(atom_decoder, f)
120
    with open(bond_encoder_path, "wb") as f:
121
        pickle.dump(bond_encoder, f)
122
    with open(bond_decoder_path, "wb") as f:
123
        pickle.dump(bond_decoder, f)
124
125
    print("Encoders/decoders created and saved.")
126
    return atom_encoder, atom_decoder, bond_encoder, bond_decoder
127
128
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
129
    data = data.to(device)
130
    a = geoutils.to_dense_adj(
131
        edge_index = data.edge_index,
132
        batch=data.batch,
133
        edge_attr=data.edge_attr,
134
        max_num_nodes=int(data.batch.shape[0]/batch_size)
135
    )
136
    x_tensor = data.x.view(batch_size,int(data.batch.shape[0]/batch_size),-1)
137
    a_tensor = label2onehot(a, b_dim, device)
138
139
    a_tensor_vec = a_tensor.reshape(batch_size,-1)
140
    x_tensor_vec = x_tensor.reshape(batch_size,-1)
141
    real_graphs = torch.concat((x_tensor_vec,a_tensor_vec),dim=-1)
142
143
    return real_graphs, a_tensor, x_tensor