Diff of /HomoAug/run_HomoAug.py [000000] .. [b40915]

Switch to unified view

a b/HomoAug/run_HomoAug.py
1
# Copyright © 2023 Institute for AI Industry Research (AIR), Tsinghua University.
2
# License: GNU GPLv3. [See details in LICENSE]
3
4
import argparse
5
import os
6
import glob
7
import Bio.PDB
8
import ray
9
import random
10
import json
11
import mmap
12
import time
13
import sys
14
import traceback
15
from ray.util.queue import Queue
16
import numpy as np
17
18
sys.path.append(".")
19
from utils.misc import execute
20
from utils.ray_tools import ProgressBar
21
from tqdm import tqdm
22
import pathlib
23
import subprocess
24
import shutil
25
26
27
class JackhmmerRunner:
    """Run jackhmmer searches for a list of ids in parallel Ray tasks.

    For every id read from ``task_ids`` the query
    ``<seq_dir>/<id>/<id>.fasta`` is searched against ``database_dir``.
    Stockholm alignments are written to ``<output_dir>_sto``, reformatted to
    a2m in ``<output_dir>_a2m`` and finally flattened into
    ``<output_dir>/<id>.fasta``.  The two temporary directories are removed
    once the conversion is finished.
    """

    def __init__(self, database_dir, task_ids, seq_dir, output_dir, n_thread):
        """Store the configuration and create the output/temporary folders.

        Args:
            database_dir: Path passed to jackhmmer as the search database.
            task_ids: Path of a text file with one id per line (only the
                first four characters of each line are used).
            seq_dir: Directory containing ``<id>/<id>.fasta`` query files.
            output_dir: Directory that receives the final FASTA files.
            n_thread: Maximum number of parallel Ray workers.
        """
        self.N_CPU_PER_THREAD = 1
        self.n_thread = n_thread
        self.database_dir = database_dir
        self.task_ids = task_ids
        self.seq_dir = seq_dir
        self.output_dir = output_dir
        # The sibling temp folders are named after output_dir without a
        # trailing "/", otherwise they would end up *inside* output_dir.
        base_dir = output_dir[:-1] if output_dir.endswith("/") else output_dir
        self.sto_dir = base_dir + "_sto"
        self.a2m_dir = base_dir + "_a2m"
        for directory in (base_dir, self.sto_dir, self.a2m_dir):
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def split_list(_list, n):
        """Split ``_list`` into ``n`` consecutive chunks of near-equal size.

        Trailing chunks may be empty when ``n`` exceeds ``len(_list)``.

        Raises:
            ValueError: If ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError("n must be a positive integer")
        chunk_size = (len(_list) - 1) // n + 1
        return [_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)]

    @ray.remote(num_cpus=1)
    def process_jobs(self, id, jobs_queue, actor):
        """Worker loop: drain ``jobs_queue`` and report progress to ``actor``.

        Args:
            self: The runner instance (passed explicitly by ``.remote``).
            id: Worker index, only used for logging.
            jobs_queue: ``ray.util.queue.Queue`` holding the job ids.
            actor: Progress actor exposing ``update.remote(n)``.

        Returns:
            Always 1.
        """
        from ray.util.queue import Empty

        print("start process", id)
        while True:
            # A non-blocking get avoids the empty()/get() race in which
            # another worker takes the last job and this one blocks forever.
            try:
                job = jobs_queue.get(block=False)
            except Empty:
                break
            try:
                self.execute_one_job(job)
            except Exception:
                print(f"failed: {job}")
                traceback.print_exc()
            try:
                actor.update.remote(1)
            except Exception:
                # Progress reporting is best effort; never fail a job for it.
                print(f"progress update failed: {job}")
        return 1

    def execute_one_job(self, job):
        """Run jackhmmer for one id and write ``<sto_dir>/<job>.sto``."""
        seq_file_path = pathlib.Path(self.seq_dir) / job[:4] / (job + ".fasta")
        output_file_path = pathlib.Path(self.sto_dir) / (job + ".sto")
        execute(
            f"jackhmmer"
            f" --cpu {self.N_CPU_PER_THREAD}"
            f" -A {output_file_path}"
            f" -o /dev/null"
            f" -E 0.001"
            f" -N 3"
            f" {seq_file_path}"
            f" {self.database_dir}"
        )
        return 1

    def change_sto_to_fasta(self):
        """Convert every ``*.sto`` alignment into a two-line-per-record FASTA.

        Each Stockholm file is reformatted to a2m with ``esl-reformat``; the
        a2m records are then joined so that every sequence occupies a single
        line.  The last two lines of the flattened text (normally the final
        record) are dropped, as in the original implementation.
        """
        for sto_path in tqdm(glob.glob(self.sto_dir + "/*.sto")):
            pdb_id = os.path.basename(sto_path).split(".")[0]
            a2m_path = self.a2m_dir + f"/{pdb_id}.a2m"
            execute(
                f"esl-reformat --informat stockholm"
                f" -o {a2m_path} a2m"
                f" {sto_path}"
            )
            pieces = []
            with open(a2m_path) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Blank lines carry no data (and used to crash here).
                        continue
                    if line.startswith(">"):
                        pieces.append("\n" + line + "\n")
                    else:
                        pieces.append(line)
            flat_lines = "".join(pieces).strip().split("\n")[:-2]
            fasta_path = self.output_dir + f"/{pdb_id}.fasta"
            with open(fasta_path, "w") as f:
                f.write("\n".join(flat_lines))

    def run_jackhmmer(self):
        """Search all ids in parallel, convert the results, then clean up."""
        with open(self.task_ids, "r") as f:
            all_jobs = [line.strip()[:4].upper() for line in f if line.strip()]
        print("all jobs:", len(all_jobs))

        # NOTE(review): presumably ProgressBar.print_until_done() waits for
        # `total` updates, so the parallel phase is skipped when there is
        # nothing to do -- confirm against utils.ray_tools.
        if all_jobs:
            ray.init(ignore_reinit_error=True)
            job_queue = Queue()
            for job in tqdm(all_jobs):
                job_queue.put(job)
            print("job queue size:", job_queue.qsize())
            pb = ProgressBar(len(all_jobs))
            actor = pb.actor
            print("actor:", actor)
            n_workers = max(1, min(self.n_thread, len(all_jobs)))
            pending = [
                self.process_jobs.remote(self, i, job_queue, actor)
                for i in range(n_workers)
            ]
            pb.print_until_done()
            ray.get(pending)
            ray.shutdown()
        print("Run homo search done!")

        # Convert first: the conversion reads sto_dir and writes a2m_dir, so
        # the temporary folders may only be removed afterwards.
        self.change_sto_to_fasta()
        shutil.rmtree(self.a2m_dir)
        shutil.rmtree(self.sto_dir)
152
153
154
class LigandPocketExtractor:
    """Split each complex into pocket chain, ligand and cleaned pocket files.

    For every sub-directory ``<homoaug_dir>/<id>`` this writes
    ``<id>_pocket_chain.pdb`` and ``<id>_ligand.pdb`` (taken from
    ``<id>_protein.pdb``) and rewrites ``<id>_pocket.pdb`` in place.
    """

    def __init__(self, id_file, homoaug_dir, n_thread):
        """Store the configuration and load ligand/chain names from ``id_file``."""
        self.id_file = id_file
        self.homoaug_dir = homoaug_dir
        self.n_thread = n_thread
        self.read_ligand_name_chain_name()

    def read_ligand_name_chain_name(self):
        """Parse ``id_file`` into ``self.ligand_name`` and ``self.chain_name``.

        Each line is split on "_": the first four characters (upper-cased)
        are the id, the second field is the chain name and the last field is
        the ligand residue name.  Blank lines are ignored; lines with fewer
        than two fields are reported and skipped.
        """
        ligand_name = {}
        chain_name = {}
        with open(self.id_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                fields = line.split("_")
                if len(fields) < 2:
                    print("malformed id line, skipped:", line)
                    continue
                pdb_id = line[:4].upper()
                ligand_name[pdb_id] = fields[-1]
                chain_name[pdb_id] = fields[1]
        self.ligand_name = ligand_name
        self.chain_name = chain_name

    def execute_one_job(self, job):
        """Extract pocket chain, ligand and cleaned pocket for one id.

        Args:
            job: The id, which is also the sub-directory name.

        Returns:
            Always 1 (also when the protein file or the ligand is missing).

        Raises:
            KeyError: If ``job`` has no entry in the id file or the expected
                chain is absent from a structure.
        """
        pdb_id = job
        prefix = self.homoaug_dir + "/" + pdb_id + "/" + pdb_id

        pdb_file = prefix + "_protein.pdb"
        if not os.path.exists(pdb_file):
            print("no pdb file for id:", pdb_id)
            return 1
        chain_id = self.chain_name[pdb_id]
        structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb_id, pdb_file)
        chain = structure[0][chain_id]

        # Pocket chain: only residues with a blank hetero flag.
        pocket_chain = Bio.PDB.Chain.Chain(chain_id)
        for residue in chain:
            if residue.id[0] == " ":
                pocket_chain.add(residue)
        io = Bio.PDB.PDBIO()
        io.set_structure(pocket_chain)
        io.save(prefix + "_pocket_chain.pdb")

        # Ligand: residues of the same chain whose name matches the id file.
        ligand_chain = Bio.PDB.Chain.Chain(chain_id)
        ligand_found = False
        for residue in chain:
            if residue.resname == self.ligand_name[pdb_id]:
                ligand_chain.add(residue)
                ligand_found = True
        if not ligand_found:
            print("ligand not found:", pdb_id)
            return 1
        io = Bio.PDB.PDBIO()
        io.set_structure(ligand_chain)
        io.save(prefix + "_ligand.pdb")

        # Rewrite <id>_pocket.pdb without hetero residues.  The input pocket
        # is read from chain "R" and written back as chain "A".
        pocket_file = prefix + "_pocket.pdb"
        structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb_id, pocket_file)
        pocket = Bio.PDB.Chain.Chain("A")
        for residue in structure[0]["R"]:
            if residue.id[0] == " ":
                pocket.add(residue)
        io = Bio.PDB.PDBIO()
        io.set_structure(pocket)
        io.save(pocket_file)
        return 1

    @ray.remote(num_cpus=1)
    def process_jobs(self, id, jobs_queue, actor):
        """Worker loop: drain ``jobs_queue`` and report progress to ``actor``.

        Args:
            self: The extractor instance (passed explicitly by ``.remote``).
            id: Worker index, only used for logging.
            jobs_queue: ``ray.util.queue.Queue`` holding the job ids.
            actor: Progress actor exposing ``update.remote(n)``.

        Returns:
            Always 1.
        """
        from ray.util.queue import Empty

        print("start process", id)
        while True:
            # A non-blocking get avoids the empty()/get() race in which
            # another worker takes the last job and this one blocks forever.
            try:
                job = jobs_queue.get(block=False)
            except Empty:
                break
            try:
                self.execute_one_job(job)
            except Exception:
                print(f"failed: {job}")
                traceback.print_exc()
            try:
                actor.update.remote(1)
            except Exception:
                # Progress reporting is best effort; never fail a job for it.
                print(f"progress update failed: {job}")
        return 1

    def run(self):
        """Process every entry of ``homoaug_dir`` in parallel Ray tasks."""
        all_jobs = [
            os.path.basename(path) for path in glob.glob(self.homoaug_dir + "/*")
        ]
        # NOTE(review): presumably ProgressBar.print_until_done() waits for
        # `total` updates, so the parallel phase is skipped when there is
        # nothing to do -- confirm against utils.ray_tools.
        if all_jobs:
            if not ray.is_initialized():
                ray.init()
            job_queue = Queue()
            for job in tqdm(all_jobs):
                job_queue.put(job)
            print("job queue size:", job_queue.qsize())
            pb = ProgressBar(len(all_jobs))
            actor = pb.actor
            n_workers = max(1, min(self.n_thread, len(all_jobs)))
            pending = [
                self.process_jobs.remote(self, i, job_queue, actor)
                for i in range(n_workers)
            ]
            pb.print_until_done()
            ray.get(pending)
        ray.shutdown()
        print("Done!")
262
263
264
class PocketPositionExtractor:
    """Mark which residues of the pocket chain belong to the pocket.

    For every id that has both ``<id>_ligand.pdb`` and ``<id>_pocket.pdb``
    a file ``<id>_pocket_position.txt`` is written.  It holds one character
    per standard residue of ``<id>_pocket_chain.pdb``: the one-letter amino
    acid code for pocket residues and "-" for all others.
    """

    def __init__(self, homoaug_dir, n_thread):
        """Store the configuration and the three-to-one letter residue map."""
        self.homoaug_dir = homoaug_dir
        self.n_thread = n_thread
        self.aa_3_to_1 = {
            "CYS": "C",
            "ASP": "D",
            "SER": "S",
            "GLN": "Q",
            "LYS": "K",
            "ILE": "I",
            "PRO": "P",
            "THR": "T",
            "PHE": "F",
            "ASN": "N",
            "GLY": "G",
            "HIS": "H",
            "LEU": "L",
            "ARG": "R",
            "TRP": "W",
            "ALA": "A",
            "VAL": "V",
            "GLU": "E",
            "TYR": "Y",
            "MET": "M",
            "MSE": "M",
            "CME": "C",
            "CSO": "C",
            "UNK": "X",
        }

    def execute_one_job(self, job):
        """Write ``<id>_pocket_position.txt`` for one id.

        A residue of the pocket chain counts as a pocket residue when at
        least one of its atoms has exactly the same coordinates as an atom
        of ``<id>_pocket.pdb``.

        Args:
            job: The id, which is also the sub-directory name.

        Returns:
            Always 1 (also when the pocket chain file has several chains).

        Raises:
            KeyError: If a pocket residue name is missing from
                ``self.aa_3_to_1``.
        """
        pdb_id = job
        prefix = self.homoaug_dir + "/" + pdb_id + "/" + pdb_id

        # Coordinates of all atoms of standard residues in the first chain
        # of the pocket file.
        pocket_structure = Bio.PDB.PDBParser().get_structure(
            pdb_id, prefix + "_pocket.pdb"
        )
        pocket = next(iter(pocket_structure[0]))
        pocket_atom_coordinates = {
            tuple(atom.get_coord())
            for residue in pocket
            if residue.id[0] == " "
            for atom in residue
        }

        chain_structure = Bio.PDB.PDBParser().get_structure(
            pdb_id, prefix + "_pocket_chain.pdb"
        )
        chains = list(chain_structure[0].get_chains())
        if len(chains) > 1:
            print("error: more than 1 chain in pocket_pdb_file:", pdb_id)
            return 1
        chain = chains[0]

        symbols = []
        for residue in chain:
            if residue.id[0] != " ":
                continue
            in_pocket = any(
                tuple(atom.get_coord()) in pocket_atom_coordinates
                for atom in residue
            )
            symbols.append(self.aa_3_to_1[residue.resname] if in_pocket else "-")

        with open(prefix + "_pocket_position.txt", "w") as f:
            f.write("".join(symbols))
        return 1

    @ray.remote(num_cpus=1)
    def process_jobs(self, id, jobs_queue, actor):
        """Worker loop: drain ``jobs_queue`` and report progress to ``actor``.

        Args:
            self: The extractor instance (passed explicitly by ``.remote``).
            id: Worker index, only used for logging.
            jobs_queue: ``ray.util.queue.Queue`` holding the job ids.
            actor: Progress actor exposing ``update.remote(n)``.

        Returns:
            Always 1.
        """
        from ray.util.queue import Empty

        print("start process", id)
        while True:
            # A non-blocking get avoids the empty()/get() race in which
            # another worker takes the last job and this one blocks forever.
            try:
                job = jobs_queue.get(block=False)
            except Empty:
                break
            try:
                self.execute_one_job(job)
            except Exception:
                print(f"failed: {job}")
                traceback.print_exc()
            try:
                actor.update.remote(1)
            except Exception:
                # Progress reporting is best effort; never fail a job for it.
                print(f"progress update failed: {job}")
        return 1

    def run(self):
        """Process every id that has a ligand file and a pocket file."""
        all_jobs = []
        for pdb_id in os.listdir(self.homoaug_dir):
            prefix = self.homoaug_dir + "/" + pdb_id + "/" + pdb_id
            if os.path.isfile(prefix + "_ligand.pdb") and os.path.isfile(
                prefix + "_pocket.pdb"
            ):
                all_jobs.append(pdb_id)

        # NOTE(review): presumably ProgressBar.print_until_done() waits for
        # `total` updates, so the parallel phase is skipped when there is
        # nothing to do -- confirm against utils.ray_tools.
        if all_jobs:
            if not ray.is_initialized():
                ray.init()
            job_queue = Queue()
            for job in tqdm(all_jobs):
                job_queue.put(job)
            print("job queue size:", job_queue.qsize())
            pb = ProgressBar(len(all_jobs))
            actor = pb.actor
            n_workers = max(1, min(self.n_thread, len(all_jobs)))
            pending = [
                self.process_jobs.remote(self, i, job_queue, actor)
                for i in range(n_workers)
            ]
            pb.print_until_done()
            ray.get(pending)
        ray.shutdown()
        print("Done!")
394
395
396
class TMalignRunner:
    """Align homologous structures onto each pocket chain with TMalign.

    For every id, the candidates listed in ``<MSA_dir>/<id>.fasta`` are
    aligned against ``<id>_pocket_chain.pdb``.  Candidates that pass both
    the TM-score and the pocket match-rate threshold are superposed onto the
    pocket chain and written to ``<homoaug_dir>/<id>/extend/<candidate>/``.
    """

    def __init__(
        self,
        max_extend_num,
        homoaug_dir,
        MSA_dir,
        AF2DB_dir,
        TMscore_threshold,
        Match_rate_threshold,
        n_thread,
    ):
        """Store the configuration.

        Args:
            max_extend_num: Maximum number of entries in one ``extend`` folder.
            homoaug_dir: Root directory with one sub-directory per id.
            MSA_dir: Directory holding ``<id>.fasta`` candidate lists.
            AF2DB_dir: Directory holding ``<candidate>.pdb`` structures.
            TMscore_threshold: Minimum TM-score (normalized by the pocket
                chain) for a candidate to be kept.
            Match_rate_threshold: Minimum pocket match rate to be kept.
            n_thread: Maximum number of parallel Ray workers.
        """
        self.n_thread = n_thread
        self.max_extend_num = max_extend_num
        self.homoaug_dir = homoaug_dir
        self.MSA_dir = MSA_dir
        self.AF2DB_dir = AF2DB_dir
        self.TMscore_threshold = TMscore_threshold
        self.Match_rate_threshold = Match_rate_threshold

    def _remove_gap_of_primary_sequence(self, primary_sequence, candidate_sequence):
        """Drop every alignment column in which the primary sequence is a gap.

        Returns:
            Tuple ``(primary_without_gap, candidate_without_gap)``.

        Raises:
            ValueError: If the two aligned sequences differ in length.
        """
        if len(primary_sequence) != len(candidate_sequence):
            raise ValueError("aligned sequences must have the same length")
        kept_columns = [
            (primary, candidate)
            for primary, candidate in zip(primary_sequence, candidate_sequence)
            if primary != "-"
        ]
        primary_without_gap = "".join(primary for primary, _ in kept_columns)
        candidate_without_gap = "".join(candidate for _, candidate in kept_columns)
        return primary_without_gap, candidate_without_gap

    def _calc_match_rate(self, pocket_position, Aligned_seq):
        """Return the fraction of pocket residues matched by ``Aligned_seq``.

        Only positions where ``pocket_position`` is not "-" are counted.
        Returns 0.0 when there is no such position (this used to raise
        ``ZeroDivisionError``).
        """
        total_cnt = 0
        match_cnt = 0
        for i, aligned_residue in enumerate(Aligned_seq):
            pocket_residue = pocket_position[i]
            if pocket_residue == "-":
                continue
            total_cnt += 1
            if pocket_residue == aligned_residue:
                match_cnt += 1
        if total_cnt == 0:
            return 0.0
        return match_cnt / total_cnt

    def _get_rotate_matrix(self, rotate_matrix_file):
        """Read the transform written by ``TMalign -m``.

        Lines 3-5 of the file are parsed; in each of them the second number
        is the translation component and the remaining numbers form one row
        of the rotation matrix.

        Returns:
            Tuple ``(u, t)`` of numpy arrays: rotation rows and translation.
        """
        with open(rotate_matrix_file, "r") as f:
            data = f.readlines()
        rows = [[float(value) for value in data[i].split()] for i in range(2, 5)]
        u = np.array([row[2:] for row in rows])
        t = np.array([row[1] for row in rows])
        return u, t

    def _read_ligand_pdb(self, ligand_pdb_file):
        """Return the coordinates of every atom in ``ligand_pdb_file``."""
        structure = Bio.PDB.PDBParser().get_structure("ligand", ligand_pdb_file)
        return [atom.get_coord() for atom in structure.get_atoms()]

    @ray.remote(num_cpus=1)
    def process_jobs(self, id, jobs_queue, actor):
        """Worker loop: drain ``jobs_queue`` and report progress to ``actor``.

        Args:
            self: The runner instance (passed explicitly by ``.remote``).
            id: Worker index, only used for logging.
            jobs_queue: ``ray.util.queue.Queue`` holding the job ids.
            actor: Progress actor exposing ``update.remote(n)``.

        Returns:
            Always 1.
        """
        from ray.util.queue import Empty

        print("start process", id)
        while True:
            # A non-blocking get avoids the empty()/get() race in which
            # another worker takes the last job and this one blocks forever.
            try:
                job = jobs_queue.get(block=False)
            except Empty:
                break
            try:
                self.execute_one_job(job)
            except Exception:
                print(f"failed: {job}")
                traceback.print_exc()
            try:
                actor.update.remote(1)
            except Exception:
                # Progress reporting is best effort; never fail a job for it.
                print(f"progress update failed: {job}")
        return 1

    def _write_extended_instance(
        self, MSA_id, MSA_pdb_file, rotation_matrix_file, extend_dir, ligand_coords
    ):
        """Superpose one accepted candidate and write its protein/pocket files."""
        extend_instance_dir = extend_dir + "/" + MSA_id + "/"
        os.makedirs(extend_instance_dir, exist_ok=True)

        structure = Bio.PDB.PDBParser().get_structure(MSA_id, MSA_pdb_file)
        # Only the first chain of the first model is transformed and trimmed.
        MSA_chain = next(iter(structure[0]))

        u, t = self._get_rotate_matrix(rotation_matrix_file)
        for residue in MSA_chain:
            for atom in residue:
                atom.set_coord(np.dot(u, np.array(atom.get_coord())) + t)

        io = Bio.PDB.PDBIO()
        io.set_structure(structure)
        io.save(extend_instance_dir + f"{MSA_id}" + "_protein.pdb")

        # Pocket: keep only atoms within 6 A of any ligand atom.
        pocket_cutoff = 6
        for residue in MSA_chain:
            far_atom_ids = [
                atom.id
                for atom in residue
                if not any(
                    np.linalg.norm(atom.get_coord() - ligand_coord) <= pocket_cutoff
                    for ligand_coord in ligand_coords
                )
            ]
            for atom_id in far_atom_ids:
                residue.detach_child(atom_id)
        io = Bio.PDB.PDBIO()
        io.set_structure(structure)
        io.save(extend_instance_dir + f"{MSA_id}" + "_pocket.pdb")

    def execute_one_job(self, id):
        """Align all candidates of one id and write the accepted ones.

        A candidate is skipped when its structure file is missing or when a
        rotation matrix file for it already exists.  The loop stops as soon
        as the ``extend`` folder holds ``max_extend_num`` entries.

        Args:
            id: The id, which is also the sub-directory name.

        Returns:
            Always 1.
        """
        print(id)
        job_dir = self.homoaug_dir + "/" + id + "/"

        pocket_position_file = job_dir + id + "_pocket_position.txt"
        if not os.path.exists(pocket_position_file):
            print("position_file not exist")
            return 1
        with open(pocket_position_file) as f:
            pocket_position = f.readline().strip()

        rotation_matrix_dir = job_dir + "rotation_matrix/"
        os.makedirs(rotation_matrix_dir, exist_ok=True)

        chain_pdb_file = job_dir + id + "_pocket_chain.pdb"
        MSA_file = self.MSA_dir + f"/{id}" + ".fasta"
        if not os.path.exists(MSA_file):
            print("MSA_file not exist")
            return 1
        # Every other line is a header; its last space-separated token is
        # used as the candidate id.
        with open(MSA_file) as f:
            MSA_ids = [
                header.strip().split(" ")[-1] for header in f.readlines()[::2]
            ]

        extend_dir = job_dir + "extend"
        os.makedirs(extend_dir, exist_ok=True)

        ligand_coords = self._read_ligand_pdb(job_dir + id + "_ligand.pdb")

        cnt = len(glob.glob(extend_dir + "/*"))
        for MSA_id in MSA_ids:
            if cnt >= self.max_extend_num:
                break
            MSA_pdb_file = self.AF2DB_dir + f"/{MSA_id}.pdb"
            if not os.path.exists(MSA_pdb_file):
                continue
            rotation_matrix_file = rotation_matrix_dir + f"{MSA_id}.txt"
            if os.path.exists(rotation_matrix_file):
                continue
            try:
                out_bytes = subprocess.check_output(
                    ["TMalign", MSA_pdb_file, chain_pdb_file, "-m", rotation_matrix_file]
                )
            except subprocess.CalledProcessError:
                # One bad candidate must not abort the remaining ones.
                print("TMalign failed:", MSA_id)
                continue
            # NOTE(review): the fixed line numbers below assume one specific
            # TMalign report layout -- confirm against the installed version.
            out_text = out_bytes.decode("utf-8").strip().split("\n")
            TMscore1 = float(out_text[12].split(" ")[1])
            TMscore2 = float(out_text[13].split(" ")[1])
            (
                sequence_from_TMalign,
                MSA_aligned_sequence,
            ) = self._remove_gap_of_primary_sequence(out_text[19], out_text[17])
            TMalign_file = rotation_matrix_dir + f"{MSA_id}_TMscore.txt"
            with open(TMalign_file, "w") as f:
                f.write("TMscore normalized to chain_pdb:" + str(TMscore2) + "\n")
                f.write("TMscore normalized to MSA_pdb:" + str(TMscore1) + "\n")
                f.write("Aligned sequence : \n")
                f.write(sequence_from_TMalign + "\n")
                f.write(MSA_aligned_sequence + "\n")

            Match_rate = self._calc_match_rate(pocket_position, MSA_aligned_sequence)
            if (
                TMscore2 >= self.TMscore_threshold
                and Match_rate >= self.Match_rate_threshold
            ):
                self._write_extended_instance(
                    MSA_id,
                    MSA_pdb_file,
                    rotation_matrix_file,
                    extend_dir,
                    ligand_coords,
                )
                cnt += 1
        print("finish: pdb_id:", id)
        return 1

    def run(self):
        """Process every id that has ligand, pocket and pocket-position files."""
        all_jobs = []
        for pdb_id in os.listdir(self.homoaug_dir):
            prefix = self.homoaug_dir + "/" + pdb_id + "/" + pdb_id
            if (
                os.path.isfile(prefix + "_ligand.pdb")
                and os.path.isfile(prefix + "_pocket.pdb")
                and os.path.isfile(prefix + "_pocket_position.txt")
            ):
                all_jobs.append(pdb_id)

        # NOTE(review): presumably ProgressBar.print_until_done() waits for
        # `total` updates, so the parallel phase is skipped when there is
        # nothing to do -- confirm against utils.ray_tools.
        if all_jobs:
            if not ray.is_initialized():
                ray.init()
            job_queue = Queue()
            for job in tqdm(all_jobs):
                job_queue.put(job)
            print("job queue size:", job_queue.qsize())
            pb = ProgressBar(len(all_jobs))
            actor = pb.actor
            n_workers = max(1, min(self.n_thread, len(all_jobs)))
            pending = [
                self.process_jobs.remote(self, i, job_queue, actor)
                for i in range(n_workers)
            ]
            pb.print_until_done()
            ray.get(pending)
        ray.shutdown()
        print("Done!")
648
649
650
class HomoAugRunner:
    """Drive the whole HomoAug pipeline from the parsed command-line options."""

    def __init__(self, args):
        """Copy the options from ``args`` (an ``argparse`` namespace)."""
        self.id_file = args.id_file
        self.homoaug_dir = args.homoaug_dir
        self.fasta_file = args.fasta_file
        self.protein_pdb_dir = args.protein_pdb_dir
        self.pocket_pdbs_dir = args.pocket_pdbs_dir
        self.jackhmmer_output_dir = args.jackhmmer_output_dir
        self.database_fasta_path = args.database_fasta_path
        self.max_extend_num = args.max_extend_num
        self.database_pdb_dir = args.database_pdb_dir
        self.TMscore_threshold = args.TMscore_threshold
        self.Match_rate_threshold = args.Match_rate_threshold
        self.n_thread = args.n_thread

    def read_dataset_fasta(self):
        """Read ``fasta_file`` into a dict mapping id to sequence.

        The file is read as alternating header/sequence lines.  The id is
        characters 2-5 of the header, upper-cased.  A trailing header
        without a sequence line is ignored.
        """
        protein_seq = {}
        with open(self.fasta_file) as f:
            fasta = f.readlines()
        for i in range(0, len(fasta) - 1, 2):
            pdb_id = fasta[i].strip()[1:5].upper()
            protein_seq[pdb_id] = fasta[i + 1].strip()
        return protein_seq

    def create_dir(self):
        """Create ``homoaug_dir`` with one populated sub-directory per id.

        Layout::

            homoaug_dir
            └── id
                ├── id.fasta
                ├── id_protein.pdb
                └── id_pocket.pdb

        Ids whose directory already exists are left untouched.  An id is
        reported and skipped when it has no sequence, or when there is not
        exactly one matching structure file, model, chain or pocket file.
        """
        protein_seq = self.read_dataset_fasta()
        os.makedirs(self.homoaug_dir, exist_ok=True)
        with open(self.id_file, "r") as f:
            lines = f.readlines()
        for line in tqdm(lines):
            pdb_id = line.strip()[:4].upper()
            if not pdb_id:
                continue
            target_dir = self.homoaug_dir + "/" + pdb_id
            if os.path.exists(target_dir):
                continue
            # Checked before creating the directory: an existing directory
            # marks the id as done on the next run.
            if pdb_id not in protein_seq:
                print("error : no sequence in fasta file", pdb_id)
                continue
            os.mkdir(target_dir)
            prefix = target_dir + "/" + pdb_id

            with open(prefix + ".fasta", "w") as f:
                f.write(">" + pdb_id + "\n" + protein_seq[pdb_id] + "\n")

            # Convert the single matching mmCIF file to PDB format.
            cif_files = glob.glob(self.protein_pdb_dir + "/" + pdb_id + "*")
            if len(cif_files) != 1:
                print(
                    f"error : expected exactly 1 cif file, found {len(cif_files)}",
                    pdb_id,
                )
                continue
            structure = Bio.PDB.MMCIFParser(QUIET=True).get_structure(
                pdb_id, cif_files[0]
            )
            n_model = len(structure)
            if n_model != 1:
                print(f"error : expected exactly 1 model, found {n_model}", pdb_id)
                continue
            n_chain = len(list(structure.get_chains()))
            if n_chain != 1:
                print(f"error : expected exactly 1 chain, found {n_chain}", pdb_id)
                continue
            io = Bio.PDB.PDBIO()
            io.set_structure(structure)
            io.save(prefix + "_protein.pdb")

            # Copy the single matching pocket file.
            pocket_files = glob.glob(self.pocket_pdbs_dir + "/" + pdb_id + "*")
            if len(pocket_files) != 1:
                print(
                    f"error : expected exactly 1 pocket file, found {len(pocket_files)}",
                    pdb_id,
                )
                continue
            # shutil avoids building a shell command from file names.
            shutil.copy(pocket_files[0], prefix + "_pocket.pdb")

    def run(self):
        """Run all pipeline stages in order.

        Stages: directory creation, jackhmmer search, ligand/pocket
        extraction, pocket position extraction and TMalign extension.
        """
        self.create_dir()

        print("# Start running jackhmmer")
        jackhmmer_runner = JackhmmerRunner(
            database_dir=self.database_fasta_path,
            task_ids=self.id_file,
            seq_dir=self.homoaug_dir,
            output_dir=self.jackhmmer_output_dir,
            n_thread=self.n_thread,
        )
        jackhmmer_runner.run_jackhmmer()

        print("# Start running ligand pocket extractor")
        ligand_pocket_extractor = LigandPocketExtractor(
            self.id_file, self.homoaug_dir, self.n_thread
        )
        ligand_pocket_extractor.run()

        print("# Start running pocket position extractor")
        pocket_position_extractor = PocketPositionExtractor(
            self.homoaug_dir, self.n_thread
        )
        pocket_position_extractor.run()

        print("# Start running TMalign")
        tmalign_runner = TMalignRunner(
            self.max_extend_num,
            self.homoaug_dir,
            self.jackhmmer_output_dir,
            self.database_pdb_dir,
            self.TMscore_threshold,
            self.Match_rate_threshold,
            self.n_thread,
        )
        tmalign_runner.run()
781
782
783
if __name__ == "__main__":
    # Command-line options in declaration order: (flag, add_argument options).
    cli_options = [
        ("--id_file", {"type": str, "default": "/drug/BioLip/tmp.id"}),
        ("--homoaug_dir", {"type": str, "default": "/drug/BioLip/homoaug_new"}),
        (
            "--fasta_file",
            {
                "type": str,
                "default": "/drug/BioLip/BioLiP_v2023-04-13_regularLigand.fasta",
            },
        ),
        ("--protein_pdb_dir", {"type": str, "default": "/drug/BioLip/protein_pdb"}),
        ("--pocket_pdbs_dir", {"type": str, "default": "/drug/BioLip/pocket_pdb"}),
        (
            "--jackhmmer_output_dir",
            {"type": str, "default": "/drug/BioLip/pdbbind_MSA_fasta"},
        ),
        (
            "--n_thread",
            {"type": int, "default": 10, "help": "number of threads for running"},
        ),
        (
            "--database_fasta_path",
            {
                "type": str,
                "default": "/data/protein/AF2DB/AFDB_HC_50.fa",
                "help": "jackhmmer search database, in fasta format",
            },
        ),
        (
            "--database_pdb_dir",
            {
                "type": str,
                "default": "/drug/AFDB_HC_50_PDB",
                "help": "homoaug search database, e.g. AF2DB",
            },
        ),
        (
            "--max_extend_num",
            {
                "type": int,
                "default": 20,
                "help": "max number of extended pocket-ligand pairs for one real pocket-ligand pair",
            },
        ),
        (
            "--TMscore_threshold",
            {
                "type": float,
                "default": 0.4,
                "help": "TMscore threshold for extending",
            },
        ),
        (
            "--Match_rate_threshold",
            {
                "type": float,
                "default": 0.4,
                "help": "Match_rate threshold for extending",
            },
        ),
    ]

    arg_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        arg_parser.add_argument(flag, **options)

    # Parse the command line and run the full pipeline.
    HomoAugRunner(arg_parser.parse_args()).run()
    print("HomoAug Done!")