Diff of /druggpt_min_multi.py [000000] .. [a621b4]

Switch to unified view

a b/druggpt_min_multi.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Sun Jul 23 08:12:43 2023
4
5
@author: Sen
6
"""
7
#%%
8
import os 
9
import argparse
10
import shutil
11
import logging
12
from openbabel import openbabel
13
# Command-line interface: -d names the directory that holds the input files.
parser = argparse.ArgumentParser()
parser.add_argument('-d', type=str, default=None, help='Input the dirpath')
args = parser.parse_args()

dirpath = args.d
# Exit with a usage message when -d is missing or empty; the trailing-slash
# normalization below needs a non-empty string to work on.
if not dirpath:
    parser.error('argument -d (input directory path) is required')
# The rest of the script expects the directory path to end with '/'.
if not dirpath.endswith('/'):
    dirpath = dirpath + '/'
20
#%%
21
def create_directory(dir_name):
    """Create the directory ``dir_name`` (including missing parents).

    Does nothing when the directory already exists. ``exist_ok=True`` is used
    instead of a separate existence check so that two processes creating the
    same directory at the same time cannot fail with ``FileExistsError``.

    Args:
        dir_name: Path of the directory to create.

    Raises:
        OSError: If the directory cannot be created, e.g. because the path
            already exists as a regular file.
    """
    os.makedirs(dir_name, exist_ok=True)
26
27
# The input directory path is expected to end with '/'.
input_dirpath = dirpath
# Drop the trailing separator, then split the remainder into the parent path
# and the name of the input folder itself.
dir_path, dir_name = os.path.split(os.path.dirname(input_dirpath))
# Results go to a sibling folder named "<input folder>_min".
output_dirpath = os.path.join(dir_path, dir_name + '_min')
create_directory(output_dirpath)
# Only CRITICAL messages pass the root logger; Open Babel's own error log is
# stopped as well.
logging.basicConfig(level=logging.CRITICAL)
openbabel.obErrorLog.StopLogging()
34
#%%
35
def sdf_min(input_sdf, output_sdf, steps=10000):
    """Energy-minimize the molecule in an SDF file and write the result.

    Reads a molecule from ``input_sdf``, sets up the MMFF94 force field for
    it, runs a steepest-descent minimization and writes the molecule with the
    updated coordinates to ``output_sdf``.

    Args:
        input_sdf: Path of the SDF file to read.
        output_sdf: Path of the SDF file to write.
        steps: Number of steepest-descent steps passed to the force field
            (10000 by default; an earlier version of the script used 5000).

    Raises:
        Exception: If the input cannot be read, the MMFF94 force field is
            unavailable or cannot be set up for the molecule, or the output
            cannot be written.
    """
    # Molecule object that receives the parsed structure.
    mol = openbabel.OBMol()

    # Converter used for both reading and writing SDF.
    conv = openbabel.OBConversion()
    conv.SetInAndOutFormats("sdf", "sdf")

    # ReadFile reports failure through its return value rather than raising,
    # so check it instead of continuing with an empty molecule.
    if not conv.ReadFile(mol, input_sdf):
        raise Exception(f"Error reading molecule from '{input_sdf}'")

    # Look up the MMFF94 force field; the lookup yields None when it is
    # not available.
    forcefield = openbabel.OBForceField.FindForceField("MMFF94")
    if forcefield is None:
        raise Exception("MMFF94 force field is not available")

    # Attach the force field to the molecule.
    if not forcefield.Setup(mol):
        raise Exception("Error setting up force field")

    # Minimize, then copy the minimized coordinates back into the molecule.
    forcefield.SteepestDescent(steps)
    forcefield.GetCoordinates(mol)

    # Write the minimized molecule; a failed write is reported so the caller
    # can clean up a partial output file.
    if not conv.WriteFile(mol, output_sdf):
        raise Exception(f"Error writing molecule to '{output_sdf}'")
60
    
61
#%% 
62
from tqdm import tqdm
63
from concurrent.futures import ProcessPoolExecutor
64
import multiprocessing
65
66
def handle_file(filename):
    """Process one entry of the input directory.

    ``.sdf`` entries are energy-minimized with ``sdf_min`` into the output
    directory; if that fails, the error is printed and any partial output
    file is removed. Subdirectories are skipped, because ``shutil.copy``
    cannot copy a directory and the resulting exception would abort the whole
    batch. Every other entry is copied unchanged to the output directory.

    Args:
        filename: Name of an entry inside ``input_dirpath`` (not a full path).
    """
    src_file = os.path.join(input_dirpath, filename)
    dst_file = os.path.join(output_dirpath, filename)

    if filename.endswith('.sdf'):
        try:
            sdf_min(src_file, dst_file)
        except Exception as e:
            print(f"An error occurred while processing the file '{src_file}': {e}")
            # Best-effort cleanup of a possibly incomplete output file.
            try:
                os.remove(dst_file)
                print(dst_file + ' was successfully removed')
            except Exception:
                # Only ask for manual cleanup if something was left behind.
                if os.path.exists(dst_file):
                    print('please remove '+dst_file)
    elif os.path.isdir(src_file):
        print('skip directory ' + src_file)
    else:
        shutil.copy(src_file, dst_file)
84
85
def main():
    """Run ``handle_file`` over every entry of the input directory in parallel.

    One worker process is started per CPU core reported by the system, and a
    tqdm progress bar advances as results come back.
    """
    entries = os.listdir(input_dirpath)
    worker_count = multiprocessing.cpu_count()

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        results = pool.map(handle_file, entries)
        # Draining the iterator drives the progress bar and re-raises any
        # exception that escaped a worker.
        for _ in tqdm(results, total=len(entries)):
            pass
94
95
96
97
#%%
98
import pandas as pd
99
import os
100
import hashlib
101
class dir_check():
    """Consistency checks between a folder's SDF files and its mapping file.

    The folder must contain ``hash_ligand_mapping.csv``, a header-less CSV
    whose first column is read as the hash and whose second column is read as
    the SMILES string.

    Attributes set by ``__init__``:
        dirpath: The folder being checked.
        mapping_file: Path of the mapping CSV inside ``dirpath``.
        mapping_data: The CSV loaded as a pandas DataFrame.
        smiles_list: Values of the second CSV column.
        hash_list: Values of the first CSV column.
        filename_list: All entry names in ``dirpath``.
        sdf_filename_list: The entries whose name ends with ``.sdf``.
    """

    def __init__(self,dirpath):
        """Load the mapping file and list the contents of ``dirpath``."""
        self.dirpath = dirpath
        self.mapping_file = os.path.join(self.dirpath, 'hash_ligand_mapping.csv')
        self.mapping_data = pd.read_csv(self.mapping_file, header=None)
        self.smiles_list = self.mapping_data.iloc[:, 1].tolist()
        self.hash_list = self.mapping_data.iloc[:, 0].tolist()
        self.filename_list = os.listdir(self.dirpath)
        self.sdf_filename_list = [x for x in self.filename_list if x.endswith('.sdf')]

    def mapping_file_check(self):
        """Verify that each stored hash is the SHA-1 hex digest of its SMILES.

        Prints ``error in mapping file`` once per inconsistent row and a
        completion message at the end.
        """
        for smiles, expected_hash in zip(self.smiles_list, self.hash_list):
            # A non-string cell (e.g. NaN from an empty field) cannot be
            # hashed and is reported as an inconsistent row instead of
            # crashing on .encode().
            if (not isinstance(smiles, str)
                    or hashlib.sha1(smiles.encode()).hexdigest() != expected_hash):
                print('error in mapping file')
        print('mapping file check completed')

    def dir_file_check(self):
        """Delete every ``.sdf`` file whose base name is not a known hash.

        The base name (file name without ``.sdf``) is looked up among the
        hashes of the mapping file; files that are not listed are reported
        and removed from ``dirpath``.
        """
        # Set membership keeps the check linear in the number of files.
        known_hashes = set(self.hash_list)
        for filename in self.sdf_filename_list:
            if filename[:-4] not in known_hashes:
                print(filename + ' not in mapping file')
                file_path = os.path.join(self.dirpath, filename)
                os.remove(file_path)
                print('remove ' + file_path)
125
#%%      
126
127
if __name__ == '__main__':
    # Minimize/copy every input entry first, then validate the output folder
    # against its mapping file.
    main()

    output_checker = dir_check(output_dirpath)
    output_checker.mapping_file_check()
    output_checker.dir_file_check()
133
    
134
    
135
    
136
    
137
    
138
    
139
    
140
    
141
    
142
    
143
    
144
    
145