Diff of /demo_cRNN.ipynb [000000] .. [58db57]

Switch to unified view

a b/demo_cRNN.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "%load_ext autoreload\n",
10
    "%autoreload 2\n",
11
    "\n",
12
    "\n",
13
    "import numpy as np\n",
14
    "import rdkit\n",
15
    "from rdkit import Chem\n",
16
    "\n",
17
    "import h5py, ast, pickle\n",
18
    "\n",
19
    "# Occupy a GPU for the model to be loaded \n",
20
    "%env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
21
    "# GPU ID, if occupied change to an available GPU ID listed under !nvidia-smi\n",
22
    "%env CUDA_VISIBLE_DEVICES=2 \n",
23
    "\n",
24
    "from ddc_pub import ddc_v3 as ddc"
25
   ]
26
  },
27
  {
28
   "cell_type": "code",
29
   "execution_count": null,
30
   "metadata": {},
31
   "outputs": [],
32
   "source": [
33
    "def get_descriptors(smiles_list, qsar_model=None, show_actives=False, active_thresh=0.5, qed_thresh=0.5):\n",
34
    "    \"\"\"Calculate molecular descriptors of SMILES in a list.\n",
35
    "    The descriptors are logp, tpsa, mw, qed, hba, hbd and probability of being active towards DRD2.\n",
36
    "    \n",
37
    "    Returns:\n",
38
    "        A np.ndarray with one row of descriptors per successfully processed SMILES; invalid or failing entries are skipped.\n",
39
    "    \"\"\"\n",
40
    "    from tqdm import tqdm_notebook as tqdm\n",
41
    "    import rdkit\n",
42
    "    from rdkit import Chem, DataStructs\n",
43
    "    from rdkit.Chem import Descriptors, rdMolDescriptors, AllChem, QED\n",
44
    "    \n",
45
    "    descriptors = []\n",
46
    "    active_mols = []\n",
47
    "    \n",
48
    "    for idx, smiles in enumerate(smiles_list):\n",
49
    "        # Convert to mol\n",
50
    "        mol = Chem.MolFromSmiles(smiles)\n",
51
    "        # If valid, calculate its properties\n",
52
    "        if mol:\n",
53
    "            try:\n",
54
    "                logp  = Descriptors.MolLogP(mol)\n",
55
    "                tpsa  = Descriptors.TPSA(mol)\n",
56
    "                molwt = Descriptors.ExactMolWt(mol)\n",
57
    "                hba   = rdMolDescriptors.CalcNumHBA(mol)\n",
58
    "                hbd   = rdMolDescriptors.CalcNumHBD(mol)\n",
59
    "                qed   = QED.qed(mol)\n",
60
    "                \n",
61
    "                # Calculate fingerprints\n",
62
    "                fp = AllChem.GetMorganFingerprintAsBitVect(mol,2, nBits=2048)\n",
63
    "                ecfp4 = np.zeros((2048,))\n",
64
    "                DataStructs.ConvertToNumpyArray(fp, ecfp4) \n",
65
    "                # Predict activity and pick only the second component\n",
66
    "                active = qsar_model.predict_proba([ecfp4])[0][1]\n",
67
    "                descriptors.append([logp, tpsa, molwt, qed, hba, hbd, active]) \n",
68
    "                \n",
69
    "                if active > active_thresh and qed > qed_thresh:\n",
70
    "                    if show_actives:\n",
71
    "                        active_mols.append(mol)\n",
72
    "                        print(\"active_proba: %.2f, QED: %.2f.\" % (active, qed))\n",
73
    "                        display(mol)\n",
74
    "                        pass\n",
75
    "                \n",
76
    "            except Exception as e:\n",
77
    "                # Sanitization error: Explicit valence for atom # 17 N, 4, is greater than permitted\n",
78
    "                print(e)\n",
79
    "        # Else, report the invalid SMILES and skip it (no descriptor row is added)\n",
80
    "        else:\n",
81
    "            print(\"Invalid generation.\")\n",
82
    "            \n",
83
    "    return np.asarray(descriptors)"
84
   ]
85
  },
86
  {
87
   "cell_type": "markdown",
88
   "metadata": {},
89
   "source": [
90
    "# Load QSAR model"
91
   ]
92
  },
93
  {
94
   "cell_type": "code",
95
   "execution_count": null,
96
   "metadata": {},
97
   "outputs": [],
98
   "source": [
99
    "qsar_model_name = \"models/qsar_model.pickle\"\n",
100
    "with open(qsar_model_name, \"rb\") as file:\n",
101
    "    qsar_model = pickle.load(file)[\"classifier_sv\"]"
102
   ]
103
  },
104
  {
105
   "cell_type": "markdown",
106
   "metadata": {},
107
   "source": [
108
    "# Load PCB cRNN"
109
   ]
110
  },
111
  {
112
   "cell_type": "code",
113
   "execution_count": null,
114
   "metadata": {},
115
   "outputs": [],
116
   "source": [
117
    "# Import existing (trained) model\n",
118
    "# Ignore any warning(s) about training configuration or non-serializable keyword arguments\n",
119
    "model_name = \"models/pcb_model\"\n",
120
    "model = ddc.DDC(model_name=model_name)"
121
   ]
122
  },
123
  {
124
   "cell_type": "markdown",
125
   "metadata": {},
126
   "source": [
127
    "# Select conditions for generated molecules"
128
   ]
129
  },
130
  {
131
   "cell_type": "code",
132
   "execution_count": null,
133
   "metadata": {},
134
   "outputs": [],
135
   "source": [
136
    "# Custom conditions\n",
137
    "logp              = 3.5\n",
138
    "tpsa              = 70.0\n",
139
    "mw                = 350.0\n",
140
    "qed               = 0.8\n",
141
    "hba               = 4.0\n",
142
    "hbd               = 1.0\n",
143
    "drd2_active_proba = 0.9\n",
144
    "\n",
145
    "target = np.array([logp, tpsa, mw, qed, hba, hbd, drd2_active_proba])"
146
   ]
147
  },
148
  {
149
   "cell_type": "code",
150
   "execution_count": null,
151
   "metadata": {},
152
   "outputs": [],
153
   "source": [
154
    "# Generate a SMILES string conditioned on the target properties\n",
155
    "smiles_out, _ = model.predict(latent=target, temp=0) # Change temp to 1 for more funky results\n",
156
    "\n",
157
    "# Calculate the properties of the generated structure and compare\n",
158
    "get_descriptors(smiles_list=[smiles_out], qsar_model=qsar_model, show_actives=True)"
159
   ]
160
  }
161
 ],
162
 "metadata": {
163
  "kernelspec": {
164
   "display_name": "ddc_env (python_3.6.7)",
165
   "language": "python",
166
   "name": "ddc_env"
167
  },
168
  "language_info": {
169
   "codemirror_mode": {
170
    "name": "ipython",
171
    "version": 3
172
   },
173
   "file_extension": ".py",
174
   "mimetype": "text/x-python",
175
   "name": "python",
176
   "nbconvert_exporter": "python",
177
   "pygments_lexer": "ipython3",
178
   "version": "3.6.7"
179
  }
180
 },
181
 "nbformat": 4,
182
 "nbformat_minor": 4
183
}