--- a
+++ b/drug_generator.py
@@ -0,0 +1,352 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Mon May  1 19:41:07 2023
+
+@author: Sen
+"""
+
+
+import os
+import sys
+import subprocess
+import hashlib
+import warnings
+import platform
+import csv
+import numpy as np
+from tqdm import tqdm
+import argparse
+import torch
+from transformers import AutoTokenizer, GPT2LMHeadModel
+import shutil
+
+from openbabel import openbabel
+import logging
+import time
+import subprocess
+import threading
+import os
+import signal
+import psutil
+
+import os
+
class Command(object):
    """Run a shell command with a hard timeout, killing the whole process
    tree if it does not finish in time (used for per-molecule ``obabel``
    conversions that occasionally hang)."""

    def __init__(self, cmd):
        self.cmd = cmd          # shell command string to execute
        self.process = None     # subprocess.Popen handle once started

    def run(self, timeout):
        """Execute ``self.cmd`` and return its exit code.

        Returns ``None`` if the process never started or was killed on
        timeout.  The command runs in a worker thread so this thread can
        enforce ``timeout`` via ``Thread.join``.  On POSIX the child is
        started in its own session (``os.setsid``) so the whole process
        group can be signalled; on Windows psutil kills the child tree.
        """
        def target():
            try:
                if os.name == 'posix':  # Unix/Linux/Mac
                    self.process = subprocess.Popen(self.cmd, shell=True, stderr=subprocess.DEVNULL, preexec_fn=os.setsid)
                else:  # Windows
                    self.process = subprocess.Popen(self.cmd, shell=True, stderr=subprocess.DEVNULL)
                self.process.communicate()
            except Exception:
                # Best-effort: a spawn failure leaves self.process as None
                # and run() reports None to the caller.
                pass

        thread = threading.Thread(target=target)
        thread.start()

        thread.join(timeout)
        if thread.is_alive():
            # Timed out: kill the whole process tree, then reap the thread.
            # BUG FIX: the original dereferenced self.process.pid without a
            # None check and did not guard against the process exiting
            # between the liveness check and the kill.
            try:
                if self.process is not None:
                    if os.name == 'posix':  # Unix/Linux/Mac
                        os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
                    else:  # Windows
                        parent = psutil.Process(self.process.pid)
                        for child in parent.children(recursive=True):
                            child.kill()
                        parent.kill()
            except Exception:
                # Process already gone (ProcessLookupError, psutil.NoSuchProcess, ...)
                pass
            thread.join()
        return self.process.returncode if self.process else None
+
+
class LigandPostprocessor:
    """Convert generated SMILES strings to 3D SDF files and maintain a
    SHA-1-hash -> SMILES mapping persisted as ``hash_ligand_mapping.csv``
    inside the output directory."""

    def __init__(self, path):
        # Mapping from SHA-1 hex digest of a SMILES string to the SMILES itself.
        self.hash_ligand_mapping = {}
        self.output_path = path  # Output directory for SDF files
        self.load_mapping()

    def load_mapping(self):
        """Load a previously saved hash->SMILES mapping, if one exists."""
        # BUG FIX: the original used the module-level global ``output_path``
        # instead of ``self.output_path``, breaking any instance constructed
        # with a different directory (or used outside the __main__ script).
        mapping_file = os.path.join(self.output_path, 'hash_ligand_mapping.csv')
        if os.path.exists(mapping_file):
            print("Found existed mapping file, now reading ...")
            with open(mapping_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    if row:  # skip blank rows defensively
                        self.hash_ligand_mapping[row[0]] = row[1]

    # Save the hash-ligand mapping to a CSV file in the output directory.
    def save_mapping(self):
        """Write the hash->SMILES mapping to ``hash_ligand_mapping.csv``."""
        # BUG FIX: same global-vs-attribute mix-up as load_mapping.
        mapping_file = os.path.join(self.output_path, 'hash_ligand_mapping.csv')
        with open(mapping_file, 'w', newline='') as f:
            writer = csv.writer(f)
            for ligand_hash, ligand in self.hash_ligand_mapping.items():
                writer.writerow([ligand_hash, ligand])

    # Filter out SDF files that are too small to be real conversions.
    def filter_sdf(self, hash_ligand_mapping_per_batch):
        """Drop batch entries whose SDF file is missing or under 2 KB
        (failed/empty conversions) and return the surviving mapping."""
        print("Filtering sdf ...")
        ligand_hash_list = list(hash_ligand_mapping_per_batch.keys())
        mapping_per_match = hash_ligand_mapping_per_batch.copy()
        for ligand_hash in tqdm(ligand_hash_list):
            filepath = os.path.join(self.output_path, ligand_hash + '.sdf')
            # BUG FIX: guard with os.path.exists — the original crashed in
            # os.path.getsize if the file vanished between batches.
            if (not os.path.exists(filepath)) or os.path.getsize(filepath) < 2 * 1024:  # 2 KB
                try:
                    os.remove(filepath)
                except OSError:
                    print(filepath)
                mapping_per_match.pop(ligand_hash)
        return mapping_per_match

    # Generate SDF files from a list of ligand SMILES using OpenBabel.
    def to_sdf(self, ligand_list_per_batch):
        """Convert a batch of SMILES to 3D SDF files via the ``obabel`` CLI.

        Each SMILES is parsed for validity, filtered by the module-level
        ``min_atoms``/``max_atoms`` limits (set from the CLI arguments),
        converted with a 10 s timeout, then recorded in
        ``self.hash_ligand_mapping`` if the file survives ``filter_sdf``.
        """
        print("Converting to sdf ...")
        hash_ligand_mapping_per_batch = {}
        for ligand in tqdm(ligand_list_per_batch):

            # Parse the SMILES first so invalid strings are skipped early.
            obConversion = openbabel.OBConversion()
            obConversion.SetInAndOutFormats("smi", "smi")
            mol = openbabel.OBMol()
            if not obConversion.ReadString(mol, ligand):
                continue  # Skip invalid SMILES

            # Count non-hydrogen atoms and enforce the size window.
            num_atoms = sum(1 for atom in openbabel.OBMolAtomIter(mol) if atom.GetAtomicNum() != 1)
            if min_atoms is not None and num_atoms < min_atoms:
                continue  # Skip molecules with too few non-hydrogen atoms
            if max_atoms is not None and num_atoms > max_atoms:
                continue  # Skip molecules with too many non-hydrogen atoms

            ligand_hash = hashlib.sha1(ligand.encode()).hexdigest()
            if ligand_hash not in self.hash_ligand_mapping:
                filepath = os.path.join(self.output_path, ligand_hash + '.sdf')

                # NOTE(review): the command is built as a shell string; SMILES
                # cannot contain single quotes, but a list-based, shell=False
                # invocation would be safer — confirm before hardening.
                if platform.system() == "Windows":
                    cmd = "obabel -:" + ligand + " -osdf -O " + filepath + " --gen3d --forcefield mmff94"
                elif platform.system() == "Linux":
                    obabel_path = shutil.which('obabel')
                    cmd = f"{obabel_path} -:'{ligand}' -osdf -O '{filepath}' --gen3d --forcefield mmff94"
                else:
                    # BUG FIX: the original fell through with ``cmd`` undefined
                    # on other platforms, raising NameError below.
                    continue

                try:
                    command = Command(cmd)
                    return_code = command.run(timeout=10)
                    if return_code != 0:  # obabel failed or was killed on timeout
                        continue
                except Exception:
                    time.sleep(1)
                    continue

                if os.path.exists(filepath):
                    # Record the hash -> SMILES mapping for this batch.
                    hash_ligand_mapping_per_batch[ligand_hash] = ligand
        self.hash_ligand_mapping.update(self.filter_sdf(hash_ligand_mapping_per_batch))

    def delete_empty_files(self):
        """Walk the output directory and delete SDF files smaller than 2 KB."""
        for foldername, subfolders, filenames in os.walk(self.output_path):
            for filename in filenames:
                # BUG FIX: restrict to .sdf files — the original also deleted
                # any small non-SDF file, including hash_ligand_mapping.csv.
                if not filename.endswith('.sdf'):
                    continue
                file_path = os.path.join(foldername, filename)
                if os.path.getsize(file_path) < 2 * 1024:  # 2 KB threshold
                    try:
                        os.remove(file_path)
                        print(f'Deleted {file_path}')
                    except OSError:
                        pass

    def check_sdf(self):
        """Remove orphan SDF files whose hash is absent from the mapping."""
        file_list = os.listdir(self.output_path)
        # BUG FIX: the original filter ``x[-4:] == 'sdf'`` compared four
        # characters against three and never matched, so this cleanup was
        # dead code.
        sdf_file_list = [x for x in file_list if x.endswith('.sdf')]
        for filename in sdf_file_list:
            hash_ = filename[:-4]  # strip the '.sdf' extension
            if hash_ not in self.hash_ligand_mapping:
                filepath = os.path.join(self.output_path, filename)
                try:
                    os.remove(filepath)
                    print('remove ' + filepath)
                except OSError:
                    pass
+                
+               
+                
+    
def about():
    """Print the DrugGPT ASCII-art banner to stdout."""
    banner = """
  _____                    _____ _____ _______ 
 |  __ \                  / ____|  __ \__   __|
 | |  | |_ __ _   _  __ _| |  __| |__) | | |   
 | |  | | '__| | | |/ _` | | |_ |  ___/  | |   
 | |__| | |  | |_| | (_| | |__| | |      | |   
 |_____/|_|   \__,_|\__, |\_____|_|      |_|   
                     __/ |                     
                    |___/                      
 A generative drug design model based on GPT2
    """
    print(banner)
+
+
+# Function to read in FASTA file
def read_fasta_file(file_path):
    """Read a FASTA file and return the concatenated sequence text.

    Header lines (those starting with '>') are skipped; every other line is
    stripped of surrounding whitespace and joined into a single string.
    """
    with open(file_path, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return ''.join(text for text in stripped_lines if not text.startswith('>'))
+
+
+                    
if __name__ == "__main__":
    # Entry point: parse CLI args, build the prompt, load the DrugGPT model,
    # then sample batches of ligand SMILES until enough unique molecules
    # survive SDF conversion and filtering.
    about()
    warnings.filterwarnings('ignore')
    
    if platform.system() == "Linux":
        # Silence Hugging Face tokenizers' fork-related parallelism warnings.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    
    #Sometimes, using Hugging Face may require a proxy.
    #os.environ["http_proxy"] = "http://your.proxy.server:port"
    #os.environ["https_proxy"] = "http://your.proxy.server:port"

    # Set up command line argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument('-p','--pro_seq', type=str, default=None, help='Input a protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.')
    parser.add_argument('-f','--fasta', type=str, default=None, help='Input a FASTA file. Default value is None. Only one of -p and -f should be specified.')
    parser.add_argument('-l','--ligand_prompt', type=str, default='', help='Input a ligand prompt. Default value is an empty string.')
    parser.add_argument('-e','--empty_input', action='store_true', default=False, help='Enable directly generate mode.')
    parser.add_argument('-n','--number',type=int, default=100, help='At least how many molecules will be generated. Default value is 100.')
    parser.add_argument('-d','--device',type=str, default='cuda', help="Hardware device to use. Default value is 'cuda'.")
    parser.add_argument('-o','--output', type=str, default='./ligand_output/', help="Output directory for generated molecules. Default value is './ligand_output/'.")
    parser.add_argument('-b','--batch_size', type=int, default=16, help="How many molecules will be generated per batch. Try to reduce this value if you have low RAM. Default value is 16.")
    parser.add_argument('-t','--temperature', type=float, default=1.0, help="Adjusts the randomness of text generation; higher values produce more diverse outputs. Default value is 1.0.")
    parser.add_argument('--top_k', type=int, default=9, help='The number of highest probability tokens to consider for top-k sampling. Defaults to 9.')
    parser.add_argument('--top_p', type=float, default=0.9, help='The cumulative probability threshold (0.0 - 1.0) for top-p (nucleus) sampling. It defines the minimum subset of tokens to consider for random sampling. Defaults to 0.9.')
    parser.add_argument('--min_atoms', type=int, default=None, help='Minimum number of non-H atoms allowed for generation.')
    parser.add_argument('--max_atoms', type=int, default=35, help='Maximum number of non-H atoms allowed for generation. Default value is 35.')
    parser.add_argument('--no_limit', action='store_true', default=False, help='Disable the default max atoms limit.')


    # Unpack arguments into module-level names; min_atoms/max_atoms are read
    # as globals by LigandPostprocessor.to_sdf.
    args = parser.parse_args()
    protein_seq = args.pro_seq
    fasta_file = args.fasta
    ligand_prompt = args.ligand_prompt
    directly_gen = args.empty_input
    num_generated = args.number
    device = args.device
    output_path = args.output
    batch_generated_size = args.batch_size
    temperature_value = args.temperature
    top_k = args.top_k
    top_p = args.top_p
    min_atoms = args.min_atoms
    max_atoms = args.max_atoms

    if args.no_limit:
        max_atoms = None
    
    if (args.min_atoms is not None) and (args.max_atoms is not None) and (args.min_atoms > args.max_atoms):
        raise ValueError("Error: min_atoms cannot be greater than max_atoms.")
    
    # A ligand prompt fixes part of the molecule, so atom-count limits are
    # disabled.  NOTE(review): only args.max_atoms/args.min_atoms are reset
    # here, not the local max_atoms/min_atoms actually used by to_sdf —
    # confirm whether the limits should also be cleared in the locals.
    if args.ligand_prompt:
        args.max_atoms = None
        args.min_atoms = None
        print("Note: --ligand_prompt is specified. --max_atoms and --min_atoms settings will be ignored.")
    
    # Suppress library logging and OpenBabel's per-molecule error spam.
    logging.basicConfig(level=logging.CRITICAL)
    openbabel.obErrorLog.StopLogging()
    os.makedirs(output_path, exist_ok=True)
    # Check if the input is either a protein amino acid sequence or a FASTA file, but not both
    if directly_gen:
        print("Now in directly generate mode.")
        prompt = "<|startoftext|><P>"
        print(prompt)
    else:
        if (not protein_seq) and (not fasta_file):
            print("Error: Input is empty.")
            sys.exit(1)
        if protein_seq and fasta_file:
            print("Error: The input should be either a protein amino acid sequence or a FASTA file, but not both.")
            sys.exit(1)
        if fasta_file:
            protein_seq = read_fasta_file(fasta_file)
        # Generate a prompt for the model: protein sequence between <P> and
        # <L>, optionally followed by a partial ligand SMILES.
        p_prompt = "<|startoftext|><P>" + protein_seq + "<L>"
        l_prompt = "" + ligand_prompt
        prompt = p_prompt + l_prompt
        print(prompt)


    # Load the tokenizer and the model (downloads from the Hugging Face hub
    # on first use).
    tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
    model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")


    model.eval()
    device = torch.device(device)
    model.to(device)

    # Create a LigandPostprocessor object (loads any existing mapping file).
    ligand_post_processor = LigandPostprocessor(output_path)

    # Encode the prompt once; it is reused for every batch.
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)

    batch_number = 0

    # In directly-generate mode, keep the raw protein/ligand pairs so they
    # can be written to generate_directly.csv at the end.
    directly_gen_protein_list = []
    directly_gen_ligand_list = []
    

    # NOTE(review): GPT-2 tokenizers often have pad_token_id = None, which
    # would make ne() fail — presumably this checkpoint defines one; confirm.
    attention_mask = generated.ne(tokenizer.pad_token_id).float()
    # Sample batches until enough unique molecules survive post-processing.
    while len(ligand_post_processor.hash_ligand_mapping) < num_generated:
        generate_ligand_list = []
        batch_number += 1
        print(f"=====Batch {batch_number}=====")
        print("Generating ligand SMILES ...")
        sample_outputs = model.generate(
            generated,
            do_sample=True,
            top_k=top_k,
            max_length=1024,
            top_p=top_p,
            temperature=temperature_value,
            num_return_sequences=batch_generated_size, 
            attention_mask=attention_mask,
            pad_token_id = tokenizer.eos_token_id
        )
        for sample_output in sample_outputs:
            # Everything after the <L> marker is the ligand SMILES.
            # NOTE(review): split('<L>')[1] raises IndexError if the model
            # emits no <L> token — confirm the checkpoint guarantees one.
            generate_ligand = tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[1]
            generate_ligand_list.append(generate_ligand)
            if directly_gen:
                directly_gen_protein_list.append(tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[0])
                directly_gen_ligand_list.append(generate_ligand)
        torch.cuda.empty_cache()
        # Convert, then prune failed/orphaned SDF files.
        ligand_post_processor.to_sdf(generate_ligand_list)
        ligand_post_processor.delete_empty_files()
        ligand_post_processor.check_sdf()
        
    if directly_gen:
        # Write protein/ligand pairs whose ligand survived post-processing.
        arr = np.array([directly_gen_protein_list, directly_gen_ligand_list])
        processed_ligand_list = ligand_post_processor.hash_ligand_mapping.values()
        with open(os.path.join(output_path, 'generate_directly.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            for index in range(arr.shape[1]):
                protein, ligand = arr[0, index], arr[1, index]
                if ligand in processed_ligand_list:
                    writer.writerow([protein, ligand])

    print("Saving mapping file ...")
    ligand_post_processor.save_mapping()
    print(f"{len(ligand_post_processor.hash_ligand_mapping)} molecules successfully generated!")

    # Hand the output directory off to the energy-minimization script.
    print("Ligand Energy Minimization")
    result = subprocess.run(['python', 'druggpt_min_multi.py', '-d', output_path])