# train.py
import os
import time
import random
import pickle
import argparse
import os.path as osp

import numpy as np  # needed below for np.random.seed

import torch
import torch.utils.data
from torch import nn
from torch_geometric.loader import DataLoader

import wandb
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem  # Chem/AllChem are used for the fingerprint metrics in train(); imported explicitly in case src.util.utils does not re-export them

torch.set_num_threads(5)
RDLogger.DisableLog('rdApp.*')

from src.util.utils import *
from src.model.models import Generator, Discriminator, simple_disc
from src.data.dataset import DruggenDataset
from src.data.utils import get_encoders_decoders, load_molecules
from src.model.loss import discriminator_loss, generator_loss


class Train(object):
    """Trainer for DrugGEN."""

    def __init__(self, config):
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel

        # Data loader.
        self.raw_file = config.raw_file  # Full path to the SMILES text file for the main dataset.
        self.drug_raw_file = config.drug_raw_file  # Full path to the SMILES text file for the drug (target) dataset.

        # Automatically infer dataset file names from raw file names
        raw_file_basename = osp.basename(self.raw_file)
        drug_raw_file_basename = osp.basename(self.drug_raw_file)

        # Get the base name without extension and append max_atom to it
        self.max_atom = config.max_atom  # Model is based on one-shot generation.
        raw_file_base = os.path.splitext(raw_file_basename)[0]
        drug_raw_file_base = os.path.splitext(drug_raw_file_basename)[0]

        # Change the extension from .smi to .pt and add max_atom to the filename
        self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
        self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"
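        # Naming example: with the default max_atom of 45, a raw file such as "data/akt_train.smi"
        # (the reference file used for the NoTarget fallback at the bottom of this script) maps to
        # the processed file name "akt_train45.pt", looked up under mol_data_dir / drug_data_dir.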

        self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
        self.drug_data_dir = config.drug_data_dir  # Directory where the drug dataset files are stored.
        self.dataset_name = self.dataset_file.split(".")[0]
        self.drugs_dataset_name = self.drugs_dataset_file.split(".")[0]
        self.features = config.features  # Boolean; False uses atom types as the only node features.
        # Additional node features can be added. Please check new_dataloarder.py Line 102.
        self.batch_size = config.batch_size  # Batch size for training.

        self.parallel = config.parallel

        # Get atom and bond encoders/decoders
        atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
            self.raw_file,
            self.drug_raw_file,
            self.max_atom
        )
        self.atom_encoder = atom_encoder
        self.atom_decoder = atom_decoder
        self.bond_encoder = bond_encoder
        self.bond_decoder = bond_decoder

        self.dataset = DruggenDataset(self.mol_data_dir,
                                      self.dataset_file,
                                      self.raw_file,
                                      self.max_atom,
                                      self.features,
                                      atom_encoder=atom_encoder,
                                      atom_decoder=atom_decoder,
                                      bond_encoder=bond_encoder,
                                      bond_decoder=bond_decoder)

        self.loader = DataLoader(self.dataset,
                                 shuffle=True,
                                 batch_size=self.batch_size,
                                 drop_last=True)  # PyG dataloader for the GAN.

        self.drugs = DruggenDataset(self.drug_data_dir,
                                    self.drugs_dataset_file,
                                    self.drug_raw_file,
                                    self.max_atom,
                                    self.features,
                                    atom_encoder=atom_encoder,
                                    atom_decoder=atom_decoder,
                                    bond_encoder=bond_encoder,
                                    bond_decoder=bond_decoder)

        self.drugs_loader = DataLoader(self.drugs,
                                       shuffle=True,
                                       batch_size=self.batch_size,
                                       drop_last=True)  # PyG dataloader for the second GAN.

        self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoder)  # Bond type dimension.
        self.vertexes = int(self.loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Model configurations.
        self.act = config.act
        self.lambda_gp = config.lambda_gp
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.ddepth = config.ddepth
        self.ddropout = config.ddropout

        # Training configurations.
        self.epoch = config.epoch
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.log_step = config.log_sample_step

        # Resume training.
        self.resume = config.resume
        self.resume_epoch = config.resume_epoch
        self.resume_iter = config.resume_iter
        self.resume_directory = config.resume_directory

        # wandb configuration
        self.use_wandb = config.use_wandb
        self.online = config.online
        self.exp_name = config.exp_name

        # Arguments for the model.
        self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)
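        # With the default CLI values below, this resolves to something like
        # "druggen_DrugGEN_glr1e-05_dlr1e-05_dim128_depth1_heads8_batch128_epoch10_dataset<raw_file_base>45_dropout0.0"
        # (where <raw_file_base> comes from --raw_file); it is reused as the subdirectory name for
        # checkpoints, samples, and the log file.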

        self.build_model(self.model_save_dir, self.arguments)

    def build_model(self, model_save_dir, arguments):
        """Create generators and discriminators."""

        ''' Generator is based on a Transformer Encoder:

            @ g_conv_dim: dimensions for MLP layers before the Transformer Encoder
            @ vertexes: maximum length of generated molecules (atom count)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout probability
            @ dim: hidden dimension of the Transformer Encoder
            @ depth: number of Transformer layers
            @ heads: number of multi-head attention heads
            @ mlp_ratio: read-out layer dimension of the Transformer
            @ drop_rate: deprecated
            @ tra_conv: whether the module creates output for a TransformerConv discriminator
        '''
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)

        ''' Discriminator implementation with a Transformer Encoder:

            @ act: activation function for the MLP
            @ vertexes: maximum length of generated molecules (molecule length)
            @ b_dim: number of bond types
            @ m_dim: number of atom types (or number of features used)
            @ dropout: dropout probability
            @ dim: hidden dimension of the Transformer Encoder
            @ depth: number of Transformer layers
            @ heads: number of multi-head attention heads
            @ mlp_ratio: read-out layer dimension of the Transformer'''

        self.D = Discriminator(self.act,
                               self.vertexes,
                               self.b_dim,
                               self.m_dim,
                               self.ddropout,
                               dim=self.dim,
                               depth=self.ddepth,
                               heads=self.heads,
                               mlp_ratio=self.mlp_ratio)

        self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        network_path = os.path.join(model_save_dir, arguments)
        self.print_network(self.G, 'G', network_path)
        self.print_network(self.D, 'D', network_path)

        if self.parallel and torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name, save_dir):
        """Print out the network information."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
        with open(network_path, "w+") as file:
            for module in model.modules():
                file.write(f"{module.__class__.__name__}:\n")
                print(module.__class__.__name__)
                for n, param in module.named_parameters():
                    if param is not None:
                        file.write(f" - {n}: {param.size()}\n")
                        print(f" - {n}: {param.size()}")
                # Only the top-level module is written; its named_parameters() already cover all submodules.
                break
            file.write(f"Total number of parameters: {num_params}\n")
            print(f"Total number of parameters: {num_params}\n\n")

    def restore_model(self, epoch, iteration, model_directory):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))

        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def save_model(self, model_directory, idx, i):
        G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx + 1, i + 1))
        D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx + 1, i + 1))
        torch.save(self.G.state_dict(), G_path)
        torch.save(self.D.state_dict(), D_path)
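        # Checkpoints are named with 1-based indices ("{epoch+1}-{iteration+1}-G.ckpt"), so pass the
        # same 1-based values to --resume_epoch / --resume_iter, which restore_model() uses verbatim.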

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self, config):
        ''' Training script starts from here. '''
        if self.use_wandb:
            mode = 'online' if self.online else 'offline'
        else:
            mode = 'disabled'
        kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
                  'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
        wandb.init(**kwargs)

        wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
        wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))

        self.model_directory = os.path.join(self.model_save_dir, self.arguments)
        self.sample_directory = os.path.join(self.sample_dir, self.arguments)
        self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
        if not os.path.exists(self.model_directory):
            os.makedirs(self.model_directory)
        if not os.path.exists(self.sample_directory):
            os.makedirs(self.sample_directory)

        # SMILES data for metrics calculation.
        with open(self.drug_raw_file, 'r') as f:
            drug_smiles = f.read().splitlines()
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

        if self.resume:
            self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)

        # Start training.
        print('Start training...')
        self.start_time = time.time()
        for idx in range(self.epoch):
            # =================================================================================== #
            #                           1. Preprocess input data                                  #
            # =================================================================================== #
            # Load the data
            dataloader_iterator = iter(self.drugs_loader)

            wandb.log({"epoch": idx})

            for i, data in enumerate(self.loader):
                try:
                    drugs = next(dataloader_iterator)
                except StopIteration:
                    dataloader_iterator = iter(self.drugs_loader)
                    drugs = next(dataloader_iterator)

                wandb.log({"iter": i})

                # Preprocess both datasets.
                real_graphs, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
                    data=drugs,
                    batch_size=self.batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                # Training configuration.
                GEN_node = x_tensor  # Generator input node features (annotation matrix of real molecules)
                GEN_edge = a_tensor  # Generator input edge features (adjacency matrix of real molecules)
                if self.submodel == "DrugGEN":
                    DISC_node = drugs_x_tensor  # Discriminator input node features (annotation matrix of drug molecules)
                    DISC_edge = drugs_a_tensor  # Discriminator input edge features (adjacency matrix of drug molecules)
                elif self.submodel == "NoTarget":
                    DISC_node = x_tensor  # Discriminator input node features (annotation matrix of real molecules)
                    DISC_edge = a_tensor  # Discriminator input edge features (adjacency matrix of real molecules)

                # =================================================================================== #
                #                               2. Train the GAN                                      #
                # =================================================================================== #
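                # Both networks are updated once per mini-batch: the discriminator first (its loss
                # includes the gradient penalty weighted by lambda_gp), then the generator on the
                # graphs it samples (returned below as node_sample / edge_sample).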

                loss = {}
                self.reset_grad()

                # Compute discriminator loss.
                node, edge, d_loss = discriminator_loss(self.G,
                                                        self.D,
                                                        DISC_edge,
                                                        DISC_node,
                                                        GEN_edge,
                                                        GEN_node,
                                                        self.batch_size,
                                                        self.device,
                                                        self.lambda_gp)
                d_total = d_loss
                wandb.log({"d_loss": d_total.item()})

                loss["d_total"] = d_total.item()
                d_total.backward()
                self.d_optimizer.step()

                self.reset_grad()

                # Compute generator loss.
                generator_output = generator_loss(self.G,
                                                  self.D,
                                                  GEN_edge,
                                                  GEN_node,
                                                  self.batch_size)
                g_loss, node, edge, node_sample, edge_sample = generator_output
                g_total = g_loss
                wandb.log({"g_loss": g_total.item()})

                loss["g_total"] = g_total.item()
                g_total.backward()
                self.g_optimizer.step()

                # Logging.
                if (i + 1) % self.log_step == 0:
                    logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
                            drug_smiles, edge_sample, node_sample, self.dataset.matrices2mol,
                            self.dataset_name, a_tensor, x_tensor, drug_vecs)

                    mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
                               idx, i, self.dataset.matrices2mol, self.dataset_name)
                    print("samples saved at epoch {} and iteration {}".format(idx, i))

                    self.save_model(self.model_directory, idx, i)
                    print("model saved at epoch {} and iteration {}".format(idx, i))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Data configuration.
    parser.add_argument('--raw_file', type=str, required=True)
    parser.add_argument('--drug_raw_file', type=str, required=False, help='Required for the DrugGEN model, optional for NoTarget')
    parser.add_argument('--drug_data_dir', type=str, default='data')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='use additional node features (otherwise atom types only)')

    # Model configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose the model subtype: DrugGEN or NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms per molecule.')
    parser.add_argument('--dim', type=int, default=128, help='Hidden dimension of the Transformer Encoder for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model in the generator.')
    parser.add_argument('--ddepth', type=int, default=1, help='Depth of the Transformer model in the discriminator.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module in the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
    parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
    parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')

    # Training configuration.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs.')
    parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for the AdamW optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for the AdamW optimizer')
    parser.add_argument('--log_dir', type=str, default='experiments/logs')
    parser.add_argument('--sample_dir', type=str, default='experiments/samples')
    parser.add_argument('--model_save_dir', type=str, default='experiments/models')
    parser.add_argument('--log_sample_step', type=int, default=1000, help='step size for sampling during training')

    # Resume training.
    parser.add_argument('--resume', action='store_true', help='resume training')
    parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
    parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
    parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')

    # wandb configuration.
    parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
    parser.add_argument('--online', action='store_true', help='use wandb online')
    parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
    parser.add_argument('--parallel', action='store_true', help='Parallelize training')

    config = parser.parse_args()

    # Check that drug_raw_file is provided when using the DrugGEN model
    if config.submodel == "DrugGEN" and not config.drug_raw_file:
        parser.error("--drug_raw_file is required when using the DrugGEN model")

    # If using the NoTarget model and drug_raw_file is not provided, fall back to a reference file
    if config.submodel == "NoTarget" and not config.drug_raw_file:
        config.drug_raw_file = "data/akt_train.smi"  # AKT reference set; only used to build the shared encoders/decoders, not as a training target.

    trainer = Train(config)
    trainer.train(config)
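
# Example invocation (a sketch; the chembl_train.smi path is a placeholder for your own dataset file):
#
#   python train.py --raw_file data/chembl_train.smi \
#                   --drug_raw_file data/akt_train.smi \
#                   --submodel DrugGEN --max_atom 45 --batch_size 128 --epoch 10 \
#                   --set_seed --seed 1 --use_wandb --online
#
# --drug_raw_file may be omitted only with --submodel NoTarget, in which case the script falls
# back to data/akt_train.smi for building the encoders/decoders, as handled above.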