Diff of /train.py [000000] .. [7d53f6]

import os
import time
import random
import pickle
import argparse
import os.path as osp

import numpy as np

import torch
import torch.utils.data
from torch import nn
from torch_geometric.loader import DataLoader

import wandb
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem

torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')

from src.util.utils import *
from src.model.models import Generator, Discriminator, simple_disc
from src.data.dataset import DruggenDataset
from src.data.utils import get_encoders_decoders, load_molecules
from src.model.loss import discriminator_loss, generator_loss


class Train(object):
    """Trainer for DrugGEN."""

    def __init__(self, config):
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel

        # Data loader.
        self.raw_file = config.raw_file  # Full path to the SMILES text file of the main dataset.
        self.drug_raw_file = config.drug_raw_file  # Full path to the SMILES text file of the drug (second) dataset.

        # Automatically infer dataset file names from raw file names
        raw_file_basename = osp.basename(self.raw_file)
        drug_raw_file_basename = osp.basename(self.drug_raw_file)

        # Get the base name without extension and append max_atom to it
        self.max_atom = config.max_atom  # Model is based on one-shot generation.
        raw_file_base = os.path.splitext(raw_file_basename)[0]
        drug_raw_file_base = os.path.splitext(drug_raw_file_basename)[0]

        # Change the extension from .smi to .pt and add max_atom to the filename
        self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
        self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"
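        # Example (illustrative file name): raw_file "data/chembl_train.smi" with max_atom=45
        # yields dataset_file "chembl_train45.pt" under mol_data_dir.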
        self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
        self.drug_data_dir = config.drug_data_dir  # Directory where the drug dataset files are stored.
        self.dataset_name = self.dataset_file.split(".")[0]
        self.drugs_dataset_name = self.drugs_dataset_file.split(".")[0]
        self.features = config.features  # Boolean; False uses atom types as the only node features.
                                         # Additional node features can be added. Please check new_dataloarder.py Line 102.
        self.batch_size = config.batch_size  # Batch size for training.

        self.parallel = config.parallel

        # Get atom and bond encoders/decoders
        atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
            self.raw_file,
            self.drug_raw_file,
            self.max_atom
        )
        self.atom_encoder = atom_encoder
        self.atom_decoder = atom_decoder
        self.bond_encoder = bond_encoder
        self.bond_decoder = bond_decoder

        self.dataset = DruggenDataset(self.mol_data_dir,
                                      self.dataset_file,
                                      self.raw_file,
                                      self.max_atom,
                                      self.features,
                                      atom_encoder=atom_encoder,
                                      atom_decoder=atom_decoder,
                                      bond_encoder=bond_encoder,
                                      bond_decoder=bond_decoder)

        self.loader = DataLoader(self.dataset,
                                 shuffle=True,
                                 batch_size=self.batch_size,
                                 drop_last=True)  # PyG dataloader for the GAN.

        self.drugs = DruggenDataset(self.drug_data_dir,
                                    self.drugs_dataset_file,
                                    self.drug_raw_file,
                                    self.max_atom,
                                    self.features,
                                    atom_encoder=atom_encoder,
                                    atom_decoder=atom_decoder,
                                    bond_encoder=bond_encoder,
                                    bond_decoder=bond_decoder)

        self.drugs_loader = DataLoader(self.drugs,
                                       shuffle=True,
                                       batch_size=self.batch_size,
                                       drop_last=True)  # PyG dataloader for the second GAN.

        self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoder)  # Bond type dimension.
        self.vertexes = int(self.loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Model configurations.
        self.act = config.act
        self.lambda_gp = config.lambda_gp
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.ddepth = config.ddepth
        self.ddropout = config.ddropout

        # Training configurations.
        self.epoch = config.epoch
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.log_step = config.log_sample_step

        # Resume training.
        self.resume = config.resume
        self.resume_epoch = config.resume_epoch
        self.resume_iter = config.resume_iter
        self.resume_directory = config.resume_directory

        # wandb configuration.
        self.use_wandb = config.use_wandb
        self.online = config.online
        self.exp_name = config.exp_name

        # Run identifier used to name experiment artifacts (checkpoints, samples, logs).
        self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)
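        # Example run identifier (illustrative, with default hyperparameters and an assumed dataset name):
        # "druggen_DrugGEN_glr1e-05_dlr1e-05_dim128_depth1_heads8_batch128_epoch10_datasetchembl_train45_dropout0.0"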
        self.build_model(self.model_save_dir, self.arguments)


    def build_model(self, model_save_dir, arguments):
        """Create generators and discriminators."""

        ''' Generator is based on Transformer Encoder:

            @ g_conv_dim: Dimensions for MLP layers before Transformer Encoder
            @ vertexes: maximum length of generated molecules (number of atoms)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout probability
            @ dim: Hidden dimension of Transformer Encoder
            @ depth: Transformer layer number
            @ heads: Number of multihead-attention heads
            @ mlp_ratio: Read-out layer dimension of Transformer
            @ drop_rate: deprecated
            @ tra_conv: Whether module creates output for TransformerConv discriminator
            '''
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)

        ''' Discriminator implementation with Transformer Encoder:

            @ act: Activation function for MLP
            @ vertexes: maximum length of generated molecules (molecule length)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout probability
            @ dim: Hidden dimension of Transformer Encoder
            @ depth: Transformer layer number
            @ heads: Number of multihead-attention heads
            @ mlp_ratio: Read-out layer dimension of Transformer'''

        self.D = Discriminator(self.act,
                               self.vertexes,
                               self.b_dim,
                               self.m_dim,
                               self.ddropout,
                               dim=self.dim,
                               depth=self.ddepth,
                               heads=self.heads,
                               mlp_ratio=self.mlp_ratio)

        self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        network_path = os.path.join(model_save_dir, arguments)
        self.print_network(self.G, 'G', network_path)
        self.print_network(self.D, 'D', network_path)

        if self.parallel and torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name, save_dir):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
        with open(network_path, "w+") as file:
            for module in model.modules():
                file.write(f"{module.__class__.__name__}:\n")
                print(module.__class__.__name__)
                for n, param in module.named_parameters():
                    if param is not None:
                        file.write(f"  - {n}: {param.size()}\n")
                        print(f"  - {n}: {param.size()}")
                break
            file.write(f"Total number of parameters: {num_params}\n")
            print(f"Total number of parameters: {num_params}\n\n")

    def restore_model(self, epoch, iteration, model_directory):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))

        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def save_model(self, model_directory, idx, i):
        """Save the trained generator and discriminator."""
        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1, i+1))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1, i+1))
        torch.save(self.G.state_dict(), G_path)
        torch.save(self.D.state_dict(), D_path)

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self, config):
        """Main training loop."""
        if self.use_wandb:
            mode = 'online' if self.online else 'offline'
        else:
            mode = 'disabled'
        kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
                  'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
        wandb.init(**kwargs)

        wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
        wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))

        self.model_directory = os.path.join(self.model_save_dir, self.arguments)
        self.sample_directory = os.path.join(self.sample_dir, self.arguments)
        self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
        if not os.path.exists(self.model_directory):
            os.makedirs(self.model_directory)
        if not os.path.exists(self.sample_directory):
            os.makedirs(self.sample_directory)

        # SMILES data for metrics calculation.
        with open(self.drug_raw_file, 'r') as f:
            drug_smiles = f.read().splitlines()
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
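        # Radius-2 Morgan fingerprints (1024 bits) of the reference drugs, passed to logging()
        # later for metric calculation against the drug set.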
        if self.resume:
            self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)

        # Start training.
        print('Start training...')
        self.start_time = time.time()
        for idx in range(self.epoch):
            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            # Load the data
            dataloader_iterator = iter(self.drugs_loader)

            wandb.log({"epoch": idx})

            for i, data in enumerate(self.loader):
                try:
                    drugs = next(dataloader_iterator)
                except StopIteration:
                    dataloader_iterator = iter(self.drugs_loader)
                    drugs = next(dataloader_iterator)
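                # The drug loader can be exhausted before the main loader within an epoch,
                # so it is simply restarted and cycled as needed.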
                wandb.log({"iter": i})

                # Preprocess both datasets.
                real_graphs, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
                    data=drugs,
                    batch_size=self.batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                # Training configuration.
                GEN_node = x_tensor             # Generator input node features (annotation matrix of real molecules)
                GEN_edge = a_tensor             # Generator input edge features (adjacency matrix of real molecules)
                if self.submodel == "DrugGEN":
                    DISC_node = drugs_x_tensor  # Discriminator input node features (annotation matrix of drug molecules)
                    DISC_edge = drugs_a_tensor  # Discriminator input edge features (adjacency matrix of drug molecules)
                elif self.submodel == "NoTarget":
                    DISC_node = x_tensor        # Discriminator input node features (annotation matrix of real molecules)
                    DISC_edge = a_tensor        # Discriminator input edge features (adjacency matrix of real molecules)
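                # In "DrugGEN" mode the discriminator's real samples come from the drug dataset,
                # while in "NoTarget" mode they come from the same dataset that feeds the generator.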
                # =================================================================================== #
                #                                     2. Train the GAN                                #
                # =================================================================================== #

                loss = {}
                self.reset_grad()
                # Compute discriminator loss.
                node, edge, d_loss = discriminator_loss(self.G,
                                                        self.D,
                                                        DISC_edge,
                                                        DISC_node,
                                                        GEN_edge,
                                                        GEN_node,
                                                        self.batch_size,
                                                        self.device,
                                                        self.lambda_gp)
                d_total = d_loss
                wandb.log({"d_loss": d_total.item()})

                loss["d_total"] = d_total.item()
                d_total.backward()
                self.d_optimizer.step()
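                # lambda_gp weights the gradient penalty term computed inside discriminator_loss
                # (see src/model/loss.py).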
                self.reset_grad()

                # Compute generator loss.
                generator_output = generator_loss(self.G,
                                                  self.D,
                                                  GEN_edge,
                                                  GEN_node,
                                                  self.batch_size)
                g_loss, node, edge, node_sample, edge_sample = generator_output
                g_total = g_loss
                wandb.log({"g_loss": g_total.item()})

                loss["g_total"] = g_total.item()
                g_total.backward()
                self.g_optimizer.step()
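                # Each batch performs one discriminator step followed by one generator step;
                # node_sample and edge_sample are the generated graphs reused for logging below.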
                # Logging.
                if (i+1) % self.log_step == 0:
                    logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
                            drug_smiles, edge_sample, node_sample, self.dataset.matrices2mol,
                            self.dataset_name, a_tensor, x_tensor, drug_vecs)

                    mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
                               idx, i, self.dataset.matrices2mol, self.dataset_name)
                    print("samples saved at epoch {} and iteration {}".format(idx, i))

                    self.save_model(self.model_directory, idx, i)
                    print("model saved at epoch {} and iteration {}".format(idx, i))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Data configuration.
    parser.add_argument('--raw_file', type=str, required=True, help='Path to the SMILES file for the main training dataset.')
    parser.add_argument('--drug_raw_file', type=str, required=False, help='Path to the SMILES file for the drug dataset. Required for the DrugGEN model, optional for NoTarget.')
    parser.add_argument('--drug_data_dir', type=str, default='data')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='use additional node features (default: atom types only)')

    # Model configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose the model subtype: DrugGEN or NoTarget.", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms per molecule.')
    parser.add_argument('--dim', type=int, default=128, help='Hidden dimension of the Transformer Encoder in the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model in the generator.')
    parser.add_argument('--ddepth', type=int, default=1, help='Depth of the Transformer model in the discriminator.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module in the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate for the generator')
    parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
    parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')

    # Training configuration.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs.')
    parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for the AdamW optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for the AdamW optimizer')
    parser.add_argument('--log_dir', type=str, default='experiments/logs')
    parser.add_argument('--sample_dir', type=str, default='experiments/samples')
    parser.add_argument('--model_save_dir', type=str, default='experiments/models')
    parser.add_argument('--log_sample_step', type=int, default=1000, help='step size for sampling during training')

    # Resume training.
    parser.add_argument('--resume', action='store_true', help='resume training from a saved checkpoint')
    parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
    parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
    parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')

    # wandb configuration.
    parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
    parser.add_argument('--online', action='store_true', help='use wandb online')
    parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
    parser.add_argument('--parallel', action='store_true', help='parallelize training across multiple GPUs')

    config = parser.parse_args()

    # Check that drug_raw_file is provided when using the DrugGEN model.
    if config.submodel == "DrugGEN" and not config.drug_raw_file:
        parser.error("--drug_raw_file is required when using the DrugGEN model")

    # If the NoTarget model is used and drug_raw_file is not provided, fall back to a reference file.
    if config.submodel == "NoTarget" and not config.drug_raw_file:
        config.drug_raw_file = "data/akt_train.smi"  # Reference file (AKT) used only to build the encoders/decoders; not used for training.

    trainer = Train(config)
    trainer.train(config)
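
    # Illustrative invocations (SMILES file paths below are placeholders for your own .smi files):
    #   python train.py --submodel DrugGEN --raw_file data/chembl_train.smi --drug_raw_file data/akt_train.smi
    #   python train.py --submodel NoTarget --raw_file data/chembl_train.smi --use_wandb --online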