Diff of /drug_generator.py [000000] .. [a621b4]

Switch to unified view

a b/drug_generator.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Mon May  1 19:41:07 2023
4
5
@author: Sen
6
"""
7
8
9
import os
10
import sys
11
import subprocess
12
import hashlib
13
import warnings
14
import platform
15
import csv
16
import numpy as np
17
from tqdm import tqdm
18
import argparse
19
import torch
20
from transformers import AutoTokenizer, GPT2LMHeadModel
21
import shutil
22
23
from openbabel import openbabel
24
import logging
25
import time
26
import subprocess
27
import threading
28
import os
29
import signal
30
import psutil
31
32
import os
33
34
class Command(object):
35
    def __init__(self, cmd):
36
        self.cmd = cmd
37
        self.process = None
38
39
    def run(self, timeout):
40
        def target():
41
            try:
42
                if os.name == 'posix':  # Unix/Linux/Mac
43
                    self.process = subprocess.Popen(self.cmd, shell=True, stderr=subprocess.DEVNULL,preexec_fn=os.setsid)
44
                else:  # Windows
45
                    self.process = subprocess.Popen(self.cmd, shell=True, stderr=subprocess.DEVNULL)
46
                self.process.communicate()
47
            except Exception:
48
                pass
49
50
        thread = threading.Thread(target=target)
51
        thread.start()
52
53
        thread.join(timeout)
54
        if thread.is_alive():
55
            if os.name == 'posix':  # Unix/Linux/Mac
56
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
57
            else:  # Windows
58
                parent = psutil.Process(self.process.pid)
59
                for child in parent.children(recursive=True):
60
                    child.kill()
61
                parent.kill()
62
            thread.join()
63
        return self.process.returncode if self.process else None
64
65
66
class LigandPostprocessor:
67
    def __init__(self, path):
68
        self.hash_ligand_mapping = {}
69
        self.output_path = path  # Output directory for SDF files
70
        self.load_mapping()
71
72
    def load_mapping(self):
73
        mapping_file = os.path.join(output_path, 'hash_ligand_mapping.csv')
74
        if os.path.exists(mapping_file):
75
            print("Found existed mapping file, now reading ...")
76
            with open(mapping_file, 'r') as f:
77
                reader = csv.reader(f)
78
                for row in reader:
79
                    self.hash_ligand_mapping[row[0]] = row[1]
80
81
    # Define a function to save the hash-ligand mapping to a file
82
    def save_mapping(self):
83
        mapping_file = os.path.join(output_path, 'hash_ligand_mapping.csv')
84
        with open(mapping_file, 'w', newline='') as f:
85
            writer = csv.writer(f)
86
            for ligand_hash, ligand in self.hash_ligand_mapping.items():
87
                writer.writerow([ligand_hash, ligand])
88
89
    # Define a function to filter out empty SDF files
90
    def filter_sdf(self, hash_ligand_mapping_per_batch):
91
        print("Filtering sdf ...")
92
        ligand_hash_list = list(hash_ligand_mapping_per_batch.keys())
93
        mapping_per_match = hash_ligand_mapping_per_batch.copy()
94
        for ligand_hash in tqdm(ligand_hash_list):
95
            filepath = os.path.join(self.output_path, ligand_hash + '.sdf')            
96
            if os.path.getsize(filepath) < 2*1024:  #2kb
97
                try:
98
                    os.remove(filepath)
99
                    #mapping_per_match.pop(ligand_hash)
100
                except Exception:
101
                    print(filepath)
102
                mapping_per_match.pop(ligand_hash)    
103
        return mapping_per_match
104
105
    # Define a function to generate SDF files from a list of ligand SMILES using OpenBabel
106
    def to_sdf(self, ligand_list_per_batch):
107
        print("Converting to sdf ...")
108
        hash_ligand_mapping_per_batch = {}
109
        for ligand in tqdm(ligand_list_per_batch):  
110
            
111
            obConversion = openbabel.OBConversion()
112
            obConversion.SetInAndOutFormats("smi", "smi")
113
            mol = openbabel.OBMol()
114
            if not obConversion.ReadString(mol, ligand):
115
                continue  # Skip invalid SMILES
116
            
117
            num_atoms = sum(1 for atom in openbabel.OBMolAtomIter(mol) if atom.GetAtomicNum() != 1)
118
            if min_atoms is not None and num_atoms < min_atoms:
119
                continue  # Skip molecules with too few non-hydrogen atoms
120
            if max_atoms is not None and num_atoms > max_atoms:
121
                continue  # Skip molecules with too many non-hydrogen atoms
122
            
123
            ligand_hash = hashlib.sha1(ligand.encode()).hexdigest()
124
            if ligand_hash not in self.hash_ligand_mapping.keys():
125
                filepath = os.path.join(self.output_path , ligand_hash + '.sdf')
126
                
127
                if platform.system() == "Windows":
128
                    cmd = "obabel -:" + ligand + " -osdf -O " + filepath + " --gen3d --forcefield mmff94"
129
                elif platform.system() == "Linux":
130
                    obabel_path = shutil.which('obabel')
131
                    cmd = f"{obabel_path} -:'{ligand}' -osdf -O '{filepath}' --gen3d --forcefield mmff94"
132
                else:pass
133
134
                try:
135
                    command = Command(cmd)
136
                    return_code = command.run(timeout=10)
137
                    if return_code != 0:  # Check the return value
138
                        #print(f"Command execution failed with return code: {return_code}")
139
                        continue  # Skip the current iteration if the command execution failed
140
                except Exception:
141
                    time.sleep(1)
142
                    continue
143
                    
144
                if os.path.exists(filepath):
145
                    hash_ligand_mapping_per_batch[ligand_hash] = ligand  # Add the hash-ligand mapping to the dictionary
146
        self.hash_ligand_mapping.update(self.filter_sdf(hash_ligand_mapping_per_batch))
147
    
148
    def delete_empty_files(self):
149
    # 遍历指定目录及其子目录中的所有文件
150
        for foldername, subfolders, filenames in os.walk(self.output_path):
151
            for filename in filenames:
152
                file_path = os.path.join(foldername, filename)
153
                # 如果文件大小为0,则删除该文件
154
                if os.path.getsize(file_path) < 2*1024:  #2kb
155
                    try:
156
                        os.remove(file_path)
157
                        print(f'Deleted {file_path}')
158
                    except Exception:
159
                        pass 
160
    
161
    
162
    def check_sdf(self):
163
        file_list = os.listdir(self.output_path)
164
        sdf_file_list = [x for x in file_list if x[-4:]=='sdf']
165
        for filename in sdf_file_list:
166
            hash_ = filename[:-4]
167
            if hash_ not in self.hash_ligand_mapping.keys():
168
                filepath = os.path.join(self.output_path,filename)
169
                try:
170
                    os.remove(filepath)
171
                    print('remove ' + filepath)
172
                except Exception:
173
                    pass
174
            else:pass    
175
                
176
               
177
                
178
    
179
def about():
180
    print("""
181
  _____                    _____ _____ _______ 
182
 |  __ \                  / ____|  __ \__   __|
183
 | |  | |_ __ _   _  __ _| |  __| |__) | | |   
184
 | |  | | '__| | | |/ _` | | |_ |  ___/  | |   
185
 | |__| | |  | |_| | (_| | |__| | |      | |   
186
 |_____/|_|   \__,_|\__, |\_____|_|      |_|   
187
                     __/ |                     
188
                    |___/                      
189
 A generative drug design model based on GPT2
190
    """)
191
192
193
# Function to read in FASTA file
194
def read_fasta_file(file_path):
195
    with open(file_path, 'r') as f:
196
        sequence = []
197
198
        for line in f:
199
            line = line.strip()
200
            if not line.startswith('>'):
201
                sequence.append(line)
202
203
        protein_sequence = ''.join(sequence)
204
    return protein_sequence
205
206
207
                    
208
if __name__ == "__main__":
209
    about()
210
    warnings.filterwarnings('ignore')
211
    
212
    if platform.system() == "Linux":
213
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
214
    
215
    #Sometimes, using Hugging Face may require a proxy.
216
    #os.environ["http_proxy"] = "http://your.proxy.server:port"
217
    #os.environ["https_proxy"] = "http://your.proxy.server:port"
218
219
    # Set up command line argument parsing
220
    parser = argparse.ArgumentParser()
221
    parser.add_argument('-p','--pro_seq', type=str, default=None, help='Input a protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.')
222
    parser.add_argument('-f','--fasta', type=str, default=None, help='Input a FASTA file. Default value is None. Only one of -p and -f should be specified.')
223
    parser.add_argument('-l','--ligand_prompt', type=str, default='', help='Input a ligand prompt. Default value is an empty string.')
224
    parser.add_argument('-e','--empty_input', action='store_true', default=False, help='Enable directly generate mode.')
225
    parser.add_argument('-n','--number',type=int, default=100, help='At least how many molecules will be generated. Default value is 100.')
226
    parser.add_argument('-d','--device',type=str, default='cuda', help="Hardware device to use. Default value is 'cuda'.")
227
    parser.add_argument('-o','--output', type=str, default='./ligand_output/', help="Output directory for generated molecules. Default value is './ligand_output/'.")
228
    parser.add_argument('-b','--batch_size', type=int, default=16, help="How many molecules will be generated per batch. Try to reduce this value if you have low RAM. Default value is 16.")
229
    parser.add_argument('-t','--temperature', type=float, default=1.0, help="Adjusts the randomness of text generation; higher values produce more diverse outputs. Default value is 1.0.")
230
    parser.add_argument('--top_k', type=int, default=9, help='The number of highest probability tokens to consider for top-k sampling. Defaults to 9.')
231
    parser.add_argument('--top_p', type=float, default=0.9, help='The cumulative probability threshold (0.0 - 1.0) for top-p (nucleus) sampling. It defines the minimum subset of tokens to consider for random sampling. Defaults to 0.9.')
232
    parser.add_argument('--min_atoms', type=int, default=None, help='Minimum number of non-H atoms allowed for generation.')
233
    parser.add_argument('--max_atoms', type=int, default=35, help='Maximum number of non-H atoms allowed for generation. Default value is 35.')
234
    parser.add_argument('--no_limit', action='store_true', default=False, help='Disable the default max atoms limit.')
235
236
237
    args = parser.parse_args()
238
    protein_seq = args.pro_seq
239
    fasta_file = args.fasta
240
    ligand_prompt = args.ligand_prompt
241
    directly_gen = args.empty_input
242
    num_generated = args.number
243
    device = args.device
244
    output_path = args.output
245
    batch_generated_size = args.batch_size
246
    temperature_value = args.temperature
247
    top_k = args.top_k
248
    top_p = args.top_p
249
    min_atoms = args.min_atoms
250
    max_atoms = args.max_atoms
251
252
    if args.no_limit:
253
        max_atoms = None
254
    
255
    if (args.min_atoms is not None) and (args.max_atoms is not None) and (args.min_atoms > args.max_atoms):
256
        raise ValueError("Error: min_atoms cannot be greater than max_atoms.")
257
    
258
    if args.ligand_prompt:
259
        args.max_atoms = None
260
        args.min_atoms = None
261
        print("Note: --ligand_prompt is specified. --max_atoms and --min_atoms settings will be ignored.")
262
    
263
    logging.basicConfig(level=logging.CRITICAL)
264
    openbabel.obErrorLog.StopLogging()
265
    os.makedirs(output_path, exist_ok=True)
266
    # Check if the input is either a protein amino acid sequence or a FASTA file, but not both
267
    if directly_gen:
268
        print("Now in directly generate mode.")
269
        prompt = "<|startoftext|><P>"
270
        print(prompt)
271
    else:
272
        if (not protein_seq) and (not fasta_file):
273
            print("Error: Input is empty.")
274
            sys.exit(1)
275
        if protein_seq and fasta_file:
276
            print("Error: The input should be either a protein amino acid sequence or a FASTA file, but not both.")
277
            sys.exit(1)
278
        if fasta_file:
279
            protein_seq = read_fasta_file(fasta_file)
280
        # Generate a prompt for the model
281
        p_prompt = "<|startoftext|><P>" + protein_seq + "<L>"
282
        l_prompt = "" + ligand_prompt
283
        prompt = p_prompt + l_prompt
284
        print(prompt)
285
286
287
    # Load the tokenizer and the model
288
    tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
289
    model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")
290
291
292
    model.eval()
293
    device = torch.device(device)
294
    model.to(device)
295
296
    # Create a LigandPostprocessor object
297
    ligand_post_processor = LigandPostprocessor(output_path)
298
299
    # Generate molecules
300
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
301
    generated = generated.to(device)
302
303
    batch_number = 0
304
305
    directly_gen_protein_list = []
306
    directly_gen_ligand_list = []
307
    
308
309
    attention_mask = generated.ne(tokenizer.pad_token_id).float()
310
    while len(ligand_post_processor.hash_ligand_mapping) < num_generated:
311
        generate_ligand_list = []
312
        batch_number += 1
313
        print(f"=====Batch {batch_number}=====")
314
        print("Generating ligand SMILES ...")
315
        sample_outputs = model.generate(
316
            generated,
317
            do_sample=True,
318
            top_k=top_k,
319
            max_length=1024,
320
            top_p=top_p,
321
            temperature=temperature_value,
322
            num_return_sequences=batch_generated_size, 
323
            attention_mask=attention_mask,
324
            pad_token_id = tokenizer.eos_token_id
325
        )
326
        for sample_output in sample_outputs:
327
            generate_ligand = tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[1]
328
            generate_ligand_list.append(generate_ligand)
329
            if directly_gen:
330
                directly_gen_protein_list.append(tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[0])
331
                directly_gen_ligand_list.append(generate_ligand)
332
        torch.cuda.empty_cache()
333
        ligand_post_processor.to_sdf(generate_ligand_list)
334
        ligand_post_processor.delete_empty_files()
335
        ligand_post_processor.check_sdf()
336
        
337
    if directly_gen:
338
        arr = np.array([directly_gen_protein_list, directly_gen_ligand_list])
339
        processed_ligand_list = ligand_post_processor.hash_ligand_mapping.values()
340
        with open(os.path.join(output_path, 'generate_directly.csv'), 'w', newline='') as f:
341
            writer = csv.writer(f)
342
            for index in range(arr.shape[1]):
343
                protein, ligand = arr[0, index], arr[1, index]
344
                if ligand in processed_ligand_list:
345
                    writer.writerow([protein, ligand])
346
347
    print("Saving mapping file ...")
348
    ligand_post_processor.save_mapping()
349
    print(f"{len(ligand_post_processor.hash_ligand_mapping)} molecules successfully generated!")
350
351
    print("Ligand Energy Minimization")
352
    result = subprocess.run(['python', 'druggpt_min_multi.py', '-d', output_path])