# -*- coding: utf-8 -*-
"""
Created on Sun Jul 23 08:12:43 2023

@author: Sen

Energy-minimise every ``.sdf`` file found in the directory given with ``-d``
and write the results into a sibling directory named ``<input>_min``.
Files that are not ``.sdf`` are copied over unchanged.  After the run, the
output directory is checked against its ``hash_ligand_mapping.csv``.
"""
#%%
import argparse
import hashlib
import logging
import multiprocessing
import os
import shutil
from concurrent.futures import ProcessPoolExecutor

import pandas as pd
from openbabel import openbabel
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('-d', type=str, default=None, help='Input the dirpath')
args = parser.parse_args()

# Without this guard a missing/empty -d crashed on ``dirpath[-1]`` with an
# unhelpful TypeError/IndexError.
if not args.d:
    parser.error('an input directory must be supplied with -d')

dirpath = args.d
if not dirpath.endswith('/'):
    dirpath = dirpath + '/'


#%%
def create_directory(dir_name):
    """Create ``dir_name`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` is used instead of a separate existence check so that
    several processes executing this module at the same time cannot race
    between the check and the creation.
    """
    os.makedirs(dir_name, exist_ok=True)


input_dirpath = dirpath  # directory path, always ends with '/'
dir_name = os.path.basename(os.path.dirname(input_dirpath))
dir_path = os.path.dirname(os.path.dirname(input_dirpath))
output_dirpath = os.path.join(dir_path, dir_name + '_min')
create_directory(output_dirpath)
logging.basicConfig(level=logging.CRITICAL)
openbabel.obErrorLog.StopLogging()


#%%
def sdf_min(input_sdf, output_sdf, steps=10000):
    """Minimise the molecule in ``input_sdf`` and write it to ``output_sdf``.

    The molecule is read with Open Babel, set up with the MMFF94 force
    field, relaxed with steepest descent and written back out as SDF.

    Args:
        input_sdf: Path of the SDF file to read.
        output_sdf: Path of the SDF file to write.
        steps: Number of steepest-descent steps (previously 5000, now
            10000 by default).

    Raises:
        RuntimeError: If the input cannot be read, the MMFF94 force field
            is unavailable or cannot be set up for the molecule, or the
            output cannot be written.
    """
    mol = openbabel.OBMol()

    # Converter used for both reading and writing.
    conv = openbabel.OBConversion()
    conv.SetInAndOutFormats("sdf", "sdf")

    if not conv.ReadFile(mol, input_sdf):
        raise RuntimeError("Error reading molecule from " + input_sdf)

    forcefield = openbabel.OBForceField.FindForceField("MMFF94")
    if forcefield is None:
        raise RuntimeError("MMFF94 force field is not available")

    if not forcefield.Setup(mol):
        raise RuntimeError("Error setting up force field")

    forcefield.SteepestDescent(steps)
    # Copy the minimised coordinates back into the molecule object.
    forcefield.GetCoordinates(mol)

    if not conv.WriteFile(mol, output_sdf):
        raise RuntimeError("Error writing molecule to " + output_sdf)


#%%
def handle_file(filename):
    """Process one entry of the input directory.

    ``.sdf`` entries are minimised into the output directory; if that fails
    the error is printed and any partially written output file is removed.
    Other regular files are copied unchanged.  Entries that are not regular
    files (e.g. sub-directories) are skipped, because ``shutil.copy`` would
    raise on them and abort the whole worker pool.

    Args:
        filename: Name of the entry, relative to ``input_dirpath``.
    """
    input_file = os.path.join(input_dirpath, filename)
    output_file = os.path.join(output_dirpath, filename)

    if not filename.endswith('.sdf'):
        if os.path.isfile(input_file):
            shutil.copy(input_file, output_file)
        return

    try:
        sdf_min(input_file, output_file)
    except Exception as e:
        # Best effort: one bad molecule must not stop the batch.
        print(f"An error occurred while processing the file '{input_file}': {e}")
        try:
            os.remove(output_file)
            print(output_file + ' was successfully removed')
        except OSError:
            if os.path.exists(output_file):
                print('please remove '+output_file)


def main():
    """Run ``handle_file`` over the input directory in a process pool."""
    file_list = os.listdir(input_dirpath)
    # One worker per CPU core.
    num_cores = multiprocessing.cpu_count()

    with ProcessPoolExecutor(max_workers=num_cores) as executor:
        # list() drains the lazy map so tqdm can show progress.
        list(tqdm(executor.map(handle_file, file_list), total=len(file_list)))


#%%
class dir_check():
    """Consistency checks between a directory and its mapping file.

    The mapping file is ``hash_ligand_mapping.csv`` inside ``dirpath``; it
    is read without a header, column 0 holding hashes and column 1 the
    SMILES strings.
    """

    def __init__(self, dirpath):
        """Load the mapping file and list the ``.sdf`` files in ``dirpath``.

        Args:
            dirpath: Directory containing the SDF files and the mapping
                file.
        """
        self.dirpath = dirpath
        self.mapping_file = os.path.join(self.dirpath, 'hash_ligand_mapping.csv')
        # dtype=str keeps hashes/SMILES as text so they are never coerced
        # to numbers by type inference.
        self.mapping_data = pd.read_csv(self.mapping_file, header=None, dtype=str)
        self.smiles_list = self.mapping_data.iloc[:, 1].tolist()
        self.hash_list = self.mapping_data.iloc[:, 0].tolist()
        self.filename_list = os.listdir(self.dirpath)
        self.sdf_filename_list = [x for x in self.filename_list if x.endswith('.sdf')]

    def mapping_file_check(self):
        """Print an error for each row whose hash is not SHA-1 of its SMILES."""
        for smiles, expected_hash in zip(self.smiles_list, self.hash_list):
            if hashlib.sha1(smiles.encode()).hexdigest() != expected_hash:
                print('error in mapping file')
        print('mapping file check completed')

    def dir_file_check(self):
        """Delete ``.sdf`` files whose base name is not a hash in the mapping."""
        # Set membership avoids scanning the whole hash list for every file.
        known_hashes = set(self.hash_list)
        for sdf_filename in self.sdf_filename_list:
            if sdf_filename[:-4] in known_hashes:
                continue
            print(sdf_filename + ' not in mapping file')
            sdf_path = os.path.join(self.dirpath, sdf_filename)
            os.remove(sdf_path)
            print('remove ' + sdf_path)


#%%
if __name__ == '__main__':
    main()

    check = dir_check(output_dirpath)
    check.mapping_file_check()
    check.dir_file_check()