--- a +++ b/examples/vae_example/chemistry_vae.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +SELFIES: a robust representation of semantically constrained graphs with an + example application in chemistry (https://arxiv.org/abs/1905.13741) + by Mario Krenn, Florian Haese, AkshatKuman Nigam, Pascal Friederich, + Alan Aspuru-Guzik. + + Variational Autoencoder (VAE) for chemistry + comparing SMILES and SELFIES representation using reconstruction + quality, diversity and latent space validity as metrics of + interest + +information: + ML framework: pytorch + chemistry framework: RDKit + + get_selfie_and_smiles_encodings_for_dataset + generate complete encoding (inclusive alphabet) for SMILES and + SELFIES given a data file + + VAEEncoder + fully connected, 3 layer neural network - encodes a one-hot + representation of molecule (in SMILES or SELFIES representation) + to latent space + + VAEDecoder + decodes point in latent space using an RNN + + latent_space_quality + samples points from latent space, decodes them into molecules, + calculates chemical validity (using RDKit's MolFromSmiles), calculates + diversity +""" + +import os +import sys +import time + +import numpy as np +import pandas as pd +import torch +import yaml +from rdkit import rdBase +from rdkit.Chem import MolFromSmiles +from torch import nn + +import selfies as sf +from data_loader import \ + multiple_selfies_to_hot, multiple_smile_to_hot + +rdBase.DisableLog('rdApp.error') + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _make_dir(directory): + os.makedirs(directory) + + +def save_models(encoder, decoder, epoch): + out_dir = './saved_models/{}'.format(epoch) + _make_dir(out_dir) + torch.save(encoder, '{}/E'.format(out_dir)) + torch.save(decoder, '{}/D'.format(out_dir)) + + +class VAEEncoder(nn.Module): + + def __init__(self, in_dimension, layer_1d, layer_2d, layer_3d, + latent_dimension): + """ + Fully Connected layers to encode molecule to latent space + """ + super(VAEEncoder, self).__init__() + self.latent_dimension = latent_dimension + + # Reduce dimension up to second last layer of Encoder + self.encode_nn = nn.Sequential( + nn.Linear(in_dimension, layer_1d), + nn.ReLU(), + nn.Linear(layer_1d, layer_2d), + nn.ReLU(), + nn.Linear(layer_2d, layer_3d), + nn.ReLU() + ) + + # Latent space mean + self.encode_mu = nn.Linear(layer_3d, latent_dimension) + + # Latent space variance + self.encode_log_var = nn.Linear(layer_3d, latent_dimension) + + @staticmethod + def reparameterize(mu, log_var): + """ + This trick is explained well here: + https://stats.stackexchange.com/a/16338 + """ + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps.mul(std).add_(mu) + + def forward(self, x): + """ + Pass throught the Encoder + """ + # Get results of encoder network + h1 = self.encode_nn(x) + + # latent space + mu = self.encode_mu(h1) + log_var = self.encode_log_var(h1) + + # Reparameterize + z = self.reparameterize(mu, log_var) + return z, mu, log_var + + +class VAEDecoder(nn.Module): + + def __init__(self, latent_dimension, gru_stack_size, gru_neurons_num, + out_dimension): + """ + Through Decoder + """ + super(VAEDecoder, self).__init__() + self.latent_dimension = latent_dimension + self.gru_stack_size = gru_stack_size + self.gru_neurons_num = gru_neurons_num + + # Simple Decoder + self.decode_RNN = nn.GRU( + input_size=latent_dimension, + hidden_size=gru_neurons_num, + num_layers=gru_stack_size, + batch_first=False) + + self.decode_FC = nn.Sequential( + nn.Linear(gru_neurons_num, out_dimension), + ) + + def init_hidden(self, batch_size=1): + weight = next(self.parameters()) + return weight.new_zeros(self.gru_stack_size, batch_size, + self.gru_neurons_num) + + def forward(self, z, hidden): + """ + A forward pass throught the entire model. + """ + + # Decode + l1, hidden = self.decode_RNN(z, hidden) + decoded = self.decode_FC(l1) # fully connected layer + + return decoded, hidden + + +def is_correct_smiles(smiles): + """ + Using RDKit to calculate whether molecule is syntactically and + semantically valid. + """ + if smiles == "": + return False + + try: + return MolFromSmiles(smiles, sanitize=True) is not None + except Exception: + return False + + +def sample_latent_space(vae_encoder, vae_decoder, sample_len): + vae_encoder.eval() + vae_decoder.eval() + + gathered_atoms = [] + + fancy_latent_point = torch.randn(1, 1, vae_encoder.latent_dimension, + device=device) + hidden = vae_decoder.init_hidden() + + # runs over letters from molecules (len=size of largest molecule) + for _ in range(sample_len): + out_one_hot, hidden = vae_decoder(fancy_latent_point, hidden) + + out_one_hot = out_one_hot.flatten().detach() + soft = nn.Softmax(0) + out_one_hot = soft(out_one_hot) + + out_index = out_one_hot.argmax(0) + gathered_atoms.append(out_index.data.cpu().tolist()) + + vae_encoder.train() + vae_decoder.train() + + return gathered_atoms + + +def latent_space_quality(vae_encoder, vae_decoder, type_of_encoding, + alphabet, sample_num, sample_len): + total_correct = 0 + all_correct_molecules = set() + print(f"latent_space_quality:" + f" Take {sample_num} samples from the latent space") + + for _ in range(1, sample_num + 1): + + molecule_pre = '' + for i in sample_latent_space(vae_encoder, vae_decoder, sample_len): + molecule_pre += alphabet[i] + molecule = molecule_pre.replace(' ', '') + + if type_of_encoding == 1: # if SELFIES, decode to SMILES + molecule = sf.decoder(molecule) + + if is_correct_smiles(molecule): + total_correct += 1 + all_correct_molecules.add(molecule) + + return total_correct, len(all_correct_molecules) + + +def quality_in_valid_set(vae_encoder, vae_decoder, data_valid, batch_size): + data_valid = data_valid[torch.randperm(data_valid.size()[0])] # shuffle + num_batches_valid = len(data_valid) // batch_size + + quality_list = [] + for batch_iteration in range(min(25, num_batches_valid)): + + # get batch + start_idx = batch_iteration * batch_size + stop_idx = (batch_iteration + 1) * batch_size + batch = data_valid[start_idx: stop_idx] + _, trg_len, _ = batch.size() + + inp_flat_one_hot = batch.flatten(start_dim=1) + latent_points, mus, log_vars = vae_encoder(inp_flat_one_hot) + + latent_points = latent_points.unsqueeze(0) + hidden = vae_decoder.init_hidden(batch_size=batch_size) + out_one_hot = torch.zeros_like(batch, device=device) + for seq_index in range(trg_len): + out_one_hot_line, hidden = vae_decoder(latent_points, hidden) + out_one_hot[:, seq_index, :] = out_one_hot_line[0] + + # assess reconstruction quality + quality = compute_recon_quality(batch, out_one_hot) + quality_list.append(quality) + + return np.mean(quality_list).item() + + +def train_model(vae_encoder, vae_decoder, + data_train, data_valid, num_epochs, batch_size, + lr_enc, lr_dec, KLD_alpha, + sample_num, sample_len, alphabet, type_of_encoding): + """ + Train the Variational Auto-Encoder + """ + + print('num_epochs: ', num_epochs) + + # initialize an instance of the model + optimizer_encoder = torch.optim.Adam(vae_encoder.parameters(), lr=lr_enc) + optimizer_decoder = torch.optim.Adam(vae_decoder.parameters(), lr=lr_dec) + + data_train = data_train.clone().detach().to(device) + num_batches_train = int(len(data_train) / batch_size) + + quality_valid_list = [0, 0, 0, 0] + for epoch in range(num_epochs): + + data_train = data_train[torch.randperm(data_train.size()[0])] + + start = time.time() + for batch_iteration in range(num_batches_train): # batch iterator + + # manual batch iterations + start_idx = batch_iteration * batch_size + stop_idx = (batch_iteration + 1) * batch_size + batch = data_train[start_idx: stop_idx] + + # reshaping for efficient parallelization + inp_flat_one_hot = batch.flatten(start_dim=1) + latent_points, mus, log_vars = vae_encoder(inp_flat_one_hot) + + # initialization hidden internal state of RNN (RNN has two inputs + # and two outputs:) + # input: latent space & hidden state + # output: one-hot encoding of one character of molecule & hidden + # state the hidden state acts as the internal memory + latent_points = latent_points.unsqueeze(0) + hidden = vae_decoder.init_hidden(batch_size=batch_size) + + # decoding from RNN N times, where N is the length of the largest + # molecule (all molecules are padded) + out_one_hot = torch.zeros_like(batch, device=device) + for seq_index in range(batch.shape[1]): + out_one_hot_line, hidden = vae_decoder(latent_points, hidden) + out_one_hot[:, seq_index, :] = out_one_hot_line[0] + + # compute ELBO + loss = compute_elbo(batch, out_one_hot, mus, log_vars, KLD_alpha) + + # perform back propogation + optimizer_encoder.zero_grad() + optimizer_decoder.zero_grad() + loss.backward(retain_graph=True) + nn.utils.clip_grad_norm_(vae_decoder.parameters(), 0.5) + optimizer_encoder.step() + optimizer_decoder.step() + + if batch_iteration % 30 == 0: + end = time.time() + + # assess reconstruction quality + quality_train = compute_recon_quality(batch, out_one_hot) + quality_valid = quality_in_valid_set(vae_encoder, vae_decoder, + data_valid, batch_size) + + report = 'Epoch: %d, Batch: %d / %d,\t(loss: %.4f\t| ' \ + 'quality: %.4f | quality_valid: %.4f)\t' \ + 'ELAPSED TIME: %.5f' \ + % (epoch, batch_iteration, num_batches_train, + loss.item(), quality_train, quality_valid, + end - start) + print(report) + start = time.time() + + quality_valid = quality_in_valid_set(vae_encoder, vae_decoder, + data_valid, batch_size) + quality_valid_list.append(quality_valid) + + # only measure validity of reconstruction improved + quality_increase = len(quality_valid_list) \ + - np.argmax(quality_valid_list) + if quality_increase == 1 and quality_valid_list[-1] > 50.: + corr, unique = latent_space_quality(vae_encoder, vae_decoder, + type_of_encoding, alphabet, + sample_num, sample_len) + else: + corr, unique = -1., -1. + + report = 'Validity: %.5f %% | Diversity: %.5f %% | ' \ + 'Reconstruction: %.5f %%' \ + % (corr * 100. / sample_num, unique * 100. / sample_num, + quality_valid) + print(report) + + with open('results.dat', 'a') as content: + content.write(report + '\n') + + if quality_valid_list[-1] < 70. and epoch > 200: + break + + if quality_increase > 20: + print('Early stopping criteria') + break + + +def compute_elbo(x, x_hat, mus, log_vars, KLD_alpha): + inp = x_hat.reshape(-1, x_hat.shape[2]) + target = x.reshape(-1, x.shape[2]).argmax(1) + + criterion = torch.nn.CrossEntropyLoss() + recon_loss = criterion(inp, target) + kld = -0.5 * torch.mean(1. + log_vars - mus.pow(2) - log_vars.exp()) + + return recon_loss + KLD_alpha * kld + + +def compute_recon_quality(x, x_hat): + x_indices = x.reshape(-1, x.shape[2]).argmax(1) + x_hat_indices = x_hat.reshape(-1, x_hat.shape[2]).argmax(1) + + differences = 1. - torch.abs(x_hat_indices - x_indices) + differences = torch.clamp(differences, min=0., max=1.).double() + quality = 100. * torch.mean(differences) + quality = quality.detach().cpu().numpy() + + return quality + + +def get_selfie_and_smiles_encodings_for_dataset(file_path): + """ + Returns encoding, alphabet and length of largest molecule in SMILES and + SELFIES, given a file containing SMILES molecules. + + input: + csv file with molecules. Column's name must be 'smiles'. + output: + - selfies encoding + - selfies alphabet + - longest selfies string + - smiles encoding (equivalent to file content) + - smiles alphabet (character based) + - longest smiles string + """ + + df = pd.read_csv(file_path) + + smiles_list = np.asanyarray(df.smiles) + + smiles_alphabet = list(set(''.join(smiles_list))) + smiles_alphabet.append(' ') # for padding + + largest_smiles_len = len(max(smiles_list, key=len)) + + print('--> Translating SMILES to SELFIES...') + selfies_list = list(map(sf.encoder, smiles_list)) + + all_selfies_symbols = sf.get_alphabet_from_selfies(selfies_list) + all_selfies_symbols.add('[nop]') + selfies_alphabet = list(all_selfies_symbols) + + largest_selfies_len = max(sf.len_selfies(s) for s in selfies_list) + + print('Finished translating SMILES to SELFIES.') + + return selfies_list, selfies_alphabet, largest_selfies_len, \ + smiles_list, smiles_alphabet, largest_smiles_len + + +def main(): + content = open('logfile.dat', 'w') + content.close() + content = open('results.dat', 'w') + content.close() + + if os.path.exists("settings.yml"): + settings = yaml.safe_load(open("settings.yml", "r")) + else: + print("Expected a file settings.yml but didn't find it.") + return + + print('--> Acquiring data...') + type_of_encoding = settings['data']['type_of_encoding'] + file_name_smiles = settings['data']['smiles_file'] + + print('Finished acquiring data.') + + if type_of_encoding == 0: + print('Representation: SMILES') + _, _, _, encoding_list, encoding_alphabet, largest_molecule_len = \ + get_selfie_and_smiles_encodings_for_dataset(file_name_smiles) + + print('--> Creating one-hot encoding...') + data = multiple_smile_to_hot(encoding_list, largest_molecule_len, + encoding_alphabet) + print('Finished creating one-hot encoding.') + + elif type_of_encoding == 1: + print('Representation: SELFIES') + encoding_list, encoding_alphabet, largest_molecule_len, _, _, _ = \ + get_selfie_and_smiles_encodings_for_dataset(file_name_smiles) + + print('--> Creating one-hot encoding...') + data = multiple_selfies_to_hot(encoding_list, largest_molecule_len, + encoding_alphabet) + print('Finished creating one-hot encoding.') + + else: + print("type_of_encoding not in {0, 1}.") + return + + len_max_molec = data.shape[1] + len_alphabet = data.shape[2] + len_max_mol_one_hot = len_max_molec * len_alphabet + + print(' ') + print(f"Alphabet has {len_alphabet} letters, " + f"largest molecule is {len_max_molec} letters.") + + data_parameters = settings['data'] + batch_size = data_parameters['batch_size'] + + encoder_parameter = settings['encoder'] + decoder_parameter = settings['decoder'] + training_parameters = settings['training'] + + vae_encoder = VAEEncoder(in_dimension=len_max_mol_one_hot, + **encoder_parameter).to(device) + vae_decoder = VAEDecoder(**decoder_parameter, + out_dimension=len(encoding_alphabet)).to(device) + + print('*' * 15, ': -->', device) + + data = torch.tensor(data, dtype=torch.float).to(device) + + train_valid_test_size = [0.5, 0.5, 0.0] + data = data[torch.randperm(data.size()[0])] + idx_train_val = int(len(data) * train_valid_test_size[0]) + idx_val_test = idx_train_val + int(len(data) * train_valid_test_size[1]) + + data_train = data[0:idx_train_val] + data_valid = data[idx_train_val:idx_val_test] + + print("start training") + train_model(**training_parameters, + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + batch_size=batch_size, + data_train=data_train, + data_valid=data_valid, + alphabet=encoding_alphabet, + type_of_encoding=type_of_encoding, + sample_len=len_max_molec) + + with open('COMPLETED', 'w') as content: + content.write('exit code: 0') + + +if __name__ == '__main__': + try: + main() + except AttributeError: + _, error_message, _ = sys.exc_info() + print(error_message)