|
a |
|
b/drug_generator.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Created on Mon May 1 19:41:07 2023 |
|
|
4 |
|
|
|
5 |
@author: Sen |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
import os |
|
|
10 |
import sys |
|
|
11 |
import subprocess |
|
|
12 |
import hashlib |
|
|
13 |
import warnings |
|
|
14 |
import platform |
|
|
15 |
import csv |
|
|
16 |
import numpy as np |
|
|
17 |
from tqdm import tqdm |
|
|
18 |
import argparse |
|
|
19 |
import torch |
|
|
20 |
from transformers import AutoTokenizer, GPT2LMHeadModel |
|
|
21 |
import shutil |
|
|
22 |
|
|
|
23 |
from openbabel import openbabel |
|
|
24 |
import logging |
|
|
25 |
import time |
|
|
26 |
import subprocess |
|
|
27 |
import threading |
|
|
28 |
import os |
|
|
29 |
import signal |
|
|
30 |
import psutil |
|
|
31 |
|
|
|
32 |
import os |
|
|
33 |
|
|
|
34 |
class Command(object):
    """Run a shell command line with a wall-clock time limit.

    The command is executed through the shell with stderr discarded. On
    POSIX the child is started in its own session so that the whole process
    group (the shell plus anything it spawned) can be signalled on timeout;
    on Windows the process tree is walked with psutil instead.
    """

    # Seconds to wait after the polite termination request before forcing
    # the kill, so a child that ignores SIGTERM cannot hang the caller.
    _KILL_GRACE_SECONDS = 5

    def __init__(self, cmd):
        """Store the shell command line; nothing is started until run()."""
        self.cmd = cmd
        self.process = None

    def run(self, timeout):
        """Execute the command and wait at most ``timeout`` seconds for it.

        Args:
            timeout: Time limit in seconds; ``None`` waits indefinitely.

        Returns:
            The process return code (negative signal number on POSIX when
            the command had to be killed), or ``None`` if the process could
            not be started at all.
        """
        popen_kwargs = {'shell': True, 'stderr': subprocess.DEVNULL}
        if os.name == 'posix':  # Unix/Linux/Mac
            # Same effect as preexec_fn=os.setsid, but safe to use in a
            # multi-threaded program (preexec_fn is documented as unsafe).
            popen_kwargs['start_new_session'] = True

        try:
            self.process = subprocess.Popen(self.cmd, **popen_kwargs)
        except (OSError, ValueError):
            # Best effort, as before: a command that cannot be launched is
            # reported to the caller as "no return code".
            self.process = None
            return None

        try:
            self.process.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            self._kill_tree(force=False)
            try:
                self.process.communicate(timeout=self._KILL_GRACE_SECONDS)
            except subprocess.TimeoutExpired:
                self._kill_tree(force=True)
                self.process.communicate()
        return self.process.returncode

    def _kill_tree(self, force=False):
        """Signal the started process and all of its descendants."""
        if os.name == 'posix':  # Unix/Linux/Mac
            sig = signal.SIGKILL if force else signal.SIGTERM
            try:
                os.killpg(os.getpgid(self.process.pid), sig)
            except ProcessLookupError:
                pass  # The process group already exited on its own.
            return

        # Windows: no process groups to signal, so kill children explicitly.
        try:
            parent = psutil.Process(self.process.pid)
            victims = parent.children(recursive=True) + [parent]
        except psutil.NoSuchProcess:
            return
        for proc in victims:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass  # Exited between enumeration and kill.
|
|
64 |
|
|
|
65 |
|
|
|
66 |
class LigandPostprocessor:
    """Convert generated SMILES strings into SDF files and track them.

    Every accepted ligand is written to ``<output_path>/<sha1>.sdf`` and
    recorded in ``hash_ligand_mapping`` (sha1 hex digest -> SMILES). The
    mapping is persisted as ``hash_ligand_mapping.csv`` in the same folder.
    """

    MAPPING_FILENAME = 'hash_ligand_mapping.csv'
    # SDF files smaller than this are treated as failed conversions.
    MIN_SDF_BYTES = 2 * 1024  # 2kb

    def __init__(self, path):
        """Remember the output directory and load any existing mapping."""
        self.hash_ligand_mapping = {}
        self.output_path = path  # Output directory for SDF files
        self.load_mapping()

    def load_mapping(self):
        """Read a previously saved hash-ligand mapping, if one exists."""
        mapping_file = os.path.join(self.output_path, self.MAPPING_FILENAME)
        if not os.path.exists(mapping_file):
            return
        print("Found existed mapping file, now reading ...")
        with open(mapping_file, 'r', newline='') as f:
            for row in csv.reader(f):
                # Skip blank or truncated rows instead of raising IndexError.
                if len(row) >= 2:
                    self.hash_ligand_mapping[row[0]] = row[1]

    def save_mapping(self):
        """Write the hash-ligand mapping to the CSV file in the output folder."""
        mapping_file = os.path.join(self.output_path, self.MAPPING_FILENAME)
        with open(mapping_file, 'w', newline='') as f:
            writer = csv.writer(f)
            for ligand_hash, ligand in self.hash_ligand_mapping.items():
                writer.writerow([ligand_hash, ligand])

    def filter_sdf(self, hash_ligand_mapping_per_batch):
        """Drop batch entries whose SDF file is missing or under 2 KB.

        Undersized files are deleted from disk. Returns a new dict holding
        only the entries that survived; the argument is not modified.
        """
        print("Filtering sdf ...")
        kept = dict(hash_ligand_mapping_per_batch)
        for ligand_hash in tqdm(list(hash_ligand_mapping_per_batch)):
            filepath = os.path.join(self.output_path, ligand_hash + '.sdf')
            try:
                too_small = os.path.getsize(filepath) < self.MIN_SDF_BYTES
            except OSError:
                # The file vanished or is unreadable: nothing usable exists.
                kept.pop(ligand_hash, None)
                continue
            if too_small:
                try:
                    os.remove(filepath)
                except OSError:
                    print(filepath)
                kept.pop(ligand_hash, None)
        return kept

    def to_sdf(self, ligand_list_per_batch):
        """Generate 3D SDF files for a batch of SMILES via the obabel CLI.

        Invalid SMILES, ligands outside the heavy-atom limits, ligands that
        are already known and ligands whose conversion fails or exceeds the
        10 second limit are skipped. Successful conversions that pass
        ``filter_sdf`` are merged into ``self.hash_ligand_mapping``.
        """
        import shlex  # Local import: only needed to build the POSIX command.

        print("Converting to sdf ...")

        # NOTE: the heavy-atom limits are module-level names set by the
        # script entry point. They are looked up defensively so the class
        # also works (without limits) when the module is merely imported.
        min_limit = globals().get('min_atoms')
        max_limit = globals().get('max_atoms')

        on_windows = platform.system() == "Windows"
        if not on_windows:
            # Resolve the executable once per batch; fall back to the bare
            # name rather than interpolating "None" into the command line.
            obabel_path = shutil.which('obabel') or 'obabel'

        obConversion = openbabel.OBConversion()
        obConversion.SetInAndOutFormats("smi", "smi")

        hash_ligand_mapping_per_batch = {}
        for ligand in tqdm(ligand_list_per_batch):
            mol = openbabel.OBMol()
            if not obConversion.ReadString(mol, ligand):
                continue  # Skip invalid SMILES

            num_atoms = sum(1 for atom in openbabel.OBMolAtomIter(mol) if atom.GetAtomicNum() != 1)
            if min_limit is not None and num_atoms < min_limit:
                continue  # Skip molecules with too few non-hydrogen atoms
            if max_limit is not None and num_atoms > max_limit:
                continue  # Skip molecules with too many non-hydrogen atoms

            ligand_hash = hashlib.sha1(ligand.encode()).hexdigest()
            if ligand_hash in self.hash_ligand_mapping or ligand_hash in hash_ligand_mapping_per_batch:
                continue  # Already converted (earlier run or this batch)

            filepath = os.path.join(self.output_path, ligand_hash + '.sdf')
            if on_windows:
                cmd = "obabel -:" + ligand + " -osdf -O " + filepath + " --gen3d --forcefield mmff94"
            else:
                # Linux, macOS and other POSIX systems. shlex.quote keeps
                # the model-generated SMILES from being interpreted by the
                # shell, whatever characters it contains.
                cmd = (
                    f"{shlex.quote(obabel_path)} -:{shlex.quote(ligand)} -osdf "
                    f"-O {shlex.quote(filepath)} --gen3d --forcefield mmff94"
                )

            try:
                return_code = Command(cmd).run(timeout=10)
            except Exception:
                # Best effort: one failing conversion must not abort the
                # whole batch. Pause briefly before moving on.
                time.sleep(1)
                continue
            if return_code != 0:  # Check the return value
                continue  # Skip this ligand if the command execution failed

            if os.path.exists(filepath):
                hash_ligand_mapping_per_batch[ligand_hash] = ligand  # Add the hash-ligand mapping to the dictionary

        self.hash_ligand_mapping.update(self.filter_sdf(hash_ligand_mapping_per_batch))

    def delete_empty_files(self):
        """Delete undersized (< 2 KB) SDF files below the output directory.

        Only ``.sdf`` files are considered, so small bookkeeping files such
        as the mapping CSV are never removed.
        """
        # Walk the output directory and all of its sub-directories.
        for foldername, subfolders, filenames in os.walk(self.output_path):
            for filename in filenames:
                if not filename.endswith('.sdf'):
                    continue
                file_path = os.path.join(foldername, filename)
                try:
                    # Files below the size threshold are failed conversions.
                    if os.path.getsize(file_path) < self.MIN_SDF_BYTES:
                        os.remove(file_path)
                        print(f'Deleted {file_path}')
                except OSError:
                    pass  # Best effort: file vanished or cannot be removed.

    def check_sdf(self):
        """Remove SDF files in the output directory that are not in the mapping."""
        for filename in os.listdir(self.output_path):
            if not filename.endswith('.sdf'):
                continue
            hash_ = filename[:-4]  # Strip the '.sdf' suffix
            if hash_ in self.hash_ligand_mapping:
                continue
            filepath = os.path.join(self.output_path, filename)
            try:
                os.remove(filepath)
                print('remove ' + filepath)
            except OSError:
                pass  # Best effort: leave files that cannot be removed.
|
|
175 |
|
|
|
176 |
|
|
|
177 |
|
|
|
178 |
|
|
|
179 |
def about():
    """Print the DrugGPT ASCII-art banner and tagline to stdout."""
    # Raw string: the banner is full of backslashes that are not valid
    # escape sequences, which newer Python versions warn about in a
    # normal string literal. The printed text is unaffected.
    print(r"""
  _____                    _____ _____ _______
 |  __ \                  / ____|  __ \__   __|
 | |  | |_ __ _   _  __ _| |  __| |__) | | |
 | |  | | '__| | | |/ _` | | |_ |  ___/  | |
 | |__| | |  | |_| | (_| | |__| | |      | |
 |_____/|_|   \__,_|\__, |\_____|_|      |_|
                     __/ |
                    |___/
 A generative drug design model based on GPT2
    """)
|
|
191 |
|
|
|
192 |
|
|
|
193 |
# Function to read in FASTA file |
|
|
194 |
def read_fasta_file(file_path):
    """Return the sequence stored in a FASTA file as one string.

    Each line is stripped of surrounding whitespace; lines starting with
    '>' (headers) are dropped and all remaining lines are concatenated in
    file order.
    """
    with open(file_path, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        residues = [text for text in stripped_lines if not text.startswith('>')]
    return ''.join(residues)
|
|
205 |
|
|
|
206 |
|
|
|
207 |
|
|
|
208 |
if __name__ == "__main__":
    # Script entry point: parse the command line, sample ligand SMILES from
    # the DrugGPT model until enough molecules were converted to SDF, save
    # the hash-ligand mapping and launch the energy-minimization script.
    about()
    warnings.filterwarnings('ignore')

    if platform.system() == "Linux":
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Sometimes, using Hugging Face may require a proxy.
    # os.environ["http_proxy"] = "http://your.proxy.server:port"
    # os.environ["https_proxy"] = "http://your.proxy.server:port"

    # Set up command line argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument('-p','--pro_seq', type=str, default=None, help='Input a protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.')
    parser.add_argument('-f','--fasta', type=str, default=None, help='Input a FASTA file. Default value is None. Only one of -p and -f should be specified.')
    parser.add_argument('-l','--ligand_prompt', type=str, default='', help='Input a ligand prompt. Default value is an empty string.')
    parser.add_argument('-e','--empty_input', action='store_true', default=False, help='Enable directly generate mode.')
    parser.add_argument('-n','--number',type=int, default=100, help='At least how many molecules will be generated. Default value is 100.')
    parser.add_argument('-d','--device',type=str, default='cuda', help="Hardware device to use. Default value is 'cuda'.")
    parser.add_argument('-o','--output', type=str, default='./ligand_output/', help="Output directory for generated molecules. Default value is './ligand_output/'.")
    parser.add_argument('-b','--batch_size', type=int, default=16, help="How many molecules will be generated per batch. Try to reduce this value if you have low RAM. Default value is 16.")
    parser.add_argument('-t','--temperature', type=float, default=1.0, help="Adjusts the randomness of text generation; higher values produce more diverse outputs. Default value is 1.0.")
    parser.add_argument('--top_k', type=int, default=9, help='The number of highest probability tokens to consider for top-k sampling. Defaults to 9.')
    parser.add_argument('--top_p', type=float, default=0.9, help='The cumulative probability threshold (0.0 - 1.0) for top-p (nucleus) sampling. It defines the minimum subset of tokens to consider for random sampling. Defaults to 0.9.')
    parser.add_argument('--min_atoms', type=int, default=None, help='Minimum number of non-H atoms allowed for generation.')
    parser.add_argument('--max_atoms', type=int, default=35, help='Maximum number of non-H atoms allowed for generation. Default value is 35.')
    parser.add_argument('--no_limit', action='store_true', default=False, help='Disable the default max atoms limit.')

    args = parser.parse_args()
    protein_seq = args.pro_seq
    fasta_file = args.fasta
    ligand_prompt = args.ligand_prompt
    directly_gen = args.empty_input
    num_generated = args.number
    device = args.device
    batch_generated_size = args.batch_size
    temperature_value = args.temperature
    top_k = args.top_k
    top_p = args.top_p

    # NOTE: output_path, min_atoms and max_atoms must stay module-level
    # names; LigandPostprocessor may read them as globals.
    output_path = args.output
    min_atoms = args.min_atoms
    max_atoms = args.max_atoms

    if args.no_limit:
        max_atoms = None

    # Validate the limits that are actually in effect, so that --no_limit
    # together with a large --min_atoms is not rejected.
    if (min_atoms is not None) and (max_atoms is not None) and (min_atoms > max_atoms):
        raise ValueError("Error: min_atoms cannot be greater than max_atoms.")

    if ligand_prompt:
        # Reset the variables that are really used for filtering; changing
        # only the attributes on ``args`` would have no effect.
        max_atoms = None
        min_atoms = None
        print("Note: --ligand_prompt is specified. --max_atoms and --min_atoms settings will be ignored.")

    logging.basicConfig(level=logging.CRITICAL)
    openbabel.obErrorLog.StopLogging()
    os.makedirs(output_path, exist_ok=True)

    # Check if the input is either a protein amino acid sequence or a FASTA file, but not both
    if directly_gen:
        print("Now in directly generate mode.")
        prompt = "<|startoftext|><P>"
        print(prompt)
    else:
        if (not protein_seq) and (not fasta_file):
            print("Error: Input is empty.")
            sys.exit(1)
        if protein_seq and fasta_file:
            print("Error: The input should be either a protein amino acid sequence or a FASTA file, but not both.")
            sys.exit(1)
        if fasta_file:
            protein_seq = read_fasta_file(fasta_file)
        # Generate a prompt for the model
        p_prompt = "<|startoftext|><P>" + protein_seq + "<L>"
        l_prompt = "" + ligand_prompt
        prompt = p_prompt + l_prompt
        print(prompt)

    # Load the tokenizer and the model
    tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
    model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")

    model.eval()
    device = torch.device(device)
    model.to(device)

    # Create a LigandPostprocessor object
    ligand_post_processor = LigandPostprocessor(output_path)

    # Generate molecules
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)

    batch_number = 0

    directly_gen_protein_list = []
    directly_gen_ligand_list = []

    attention_mask = generated.ne(tokenizer.pad_token_id).float()
    while len(ligand_post_processor.hash_ligand_mapping) < num_generated:
        generate_ligand_list = []
        batch_number += 1
        print(f"=====Batch {batch_number}=====")
        print("Generating ligand SMILES ...")
        sample_outputs = model.generate(
            generated,
            do_sample=True,
            top_k=top_k,
            max_length=1024,
            top_p=top_p,
            temperature=temperature_value,
            num_return_sequences=batch_generated_size,
            attention_mask=attention_mask,
            pad_token_id=tokenizer.eos_token_id,
        )
        for sample_output in sample_outputs:
            # Decode once and split into "<protein><L><ligand>..." pieces.
            pieces = tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')
            if len(pieces) < 2:
                # No ligand separator in this sample: there is no ligand
                # to extract, so skip it instead of raising IndexError.
                continue
            generate_ligand = pieces[1]
            generate_ligand_list.append(generate_ligand)
            if directly_gen:
                directly_gen_protein_list.append(pieces[0])
                directly_gen_ligand_list.append(generate_ligand)
        torch.cuda.empty_cache()
        ligand_post_processor.to_sdf(generate_ligand_list)
        ligand_post_processor.delete_empty_files()
        ligand_post_processor.check_sdf()

    if directly_gen:
        # Keep only the protein/ligand pairs whose ligand was successfully
        # converted; a set makes each membership test O(1).
        processed_ligands = set(ligand_post_processor.hash_ligand_mapping.values())
        with open(os.path.join(output_path, 'generate_directly.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            for protein, ligand in zip(directly_gen_protein_list, directly_gen_ligand_list):
                if ligand in processed_ligands:
                    writer.writerow([protein, ligand])

    print("Saving mapping file ...")
    ligand_post_processor.save_mapping()
    print(f"{len(ligand_post_processor.hash_ligand_mapping)} molecules successfully generated!")

    print("Ligand Energy Minimization")
    # Use the interpreter running this script so the follow-up step sees
    # the same environment (a bare 'python' may resolve to another one).
    result = subprocess.run([sys.executable, 'druggpt_min_multi.py', '-d', output_path])