|
a |
|
b/Code/Drug Discovery/ProtBert/Protein Fill-Mask, kkawchak.ipynb |
|
|
1 |
{ |
|
|
2 |
"nbformat": 4, |
|
|
3 |
"nbformat_minor": 0, |
|
|
4 |
"metadata": { |
|
|
5 |
"colab": { |
|
|
6 |
"provenance": [], |
|
|
7 |
"machine_shape": "hm" |
|
|
8 |
}, |
|
|
9 |
"kernelspec": { |
|
|
10 |
"name": "python3", |
|
|
11 |
"display_name": "Python 3" |
|
|
12 |
}, |
|
|
13 |
"language_info": { |
|
|
14 |
"name": "python" |
|
|
15 |
} |
|
|
16 |
}, |
|
|
17 |
"cells": [ |
|
|
18 |
{ |
|
|
19 |
"cell_type": "code", |
|
|
20 |
"execution_count": 8, |
|
|
21 |
"metadata": { |
|
|
22 |
"id": "-oTEyfkwNdmj" |
|
|
23 |
}, |
|
|
24 |
"outputs": [], |
|
|
25 |
"source": [ |
|
|
26 |
"# from transformers import pipeline, BertForMaskedLM, BertTokenizer, BertModel\n", |
|
|
27 |
"# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n", |
|
|
28 |
"# import time\n", |
|
|
29 |
"# import torch\n", |
|
|
30 |
"# import math\n", |
|
|
31 |
"# import re" |
|
|
32 |
] |
|
|
33 |
}, |
|
|
34 |
{ |
|
|
35 |
"cell_type": "code", |
|
|
36 |
"source": [ |
|
|
37 |
"# Ref 1: Hugging Face, Rostlab/prot_bert. https://huggingface.co/Rostlab/prot_bert\n", |
|
|
38 |
"# Ref 2: Hayes, J. Medium, 2023 https://medium.com/labs-notebook/large-language-models-for-drug-discovery-7ddfc005e0bb\n", |
|
|
39 |
"# Ref 3: HF Intro, AssemblyAI 2022 https://www.youtube.com/watch?v=QEaBAZQCtwE&t=4s\n", |
|
|
40 |
"# Ref 4: ChatGPT3.5 Coding assistance 2024 https://chat.openai.com/" |
|
|
41 |
], |
|
|
42 |
"metadata": { |
|
|
43 |
"id": "jzr3h73L1w-7" |
|
|
44 |
}, |
|
|
45 |
"execution_count": 9, |
|
|
46 |
"outputs": [] |
|
|
47 |
}, |
|
|
48 |
{ |
|
|
49 |
"cell_type": "code", |
|
|
50 |
"source": [ |
|
|
51 |
"seconds = time.time()\n", |
|
|
52 |
"print(\"Time in seconds since beginning of run:\", seconds)\n", |
|
|
53 |
"local_time = time.ctime(seconds)\n", |
|
|
54 |
"print(local_time)" |
|
|
55 |
], |
|
|
56 |
"metadata": { |
|
|
57 |
"colab": { |
|
|
58 |
"base_uri": "https://localhost:8080/", |
|
|
59 |
"height": 0 |
|
|
60 |
}, |
|
|
61 |
"id": "2K0dcoVBuBGy", |
|
|
62 |
"outputId": "8a904c07-4c64-41e8-94d1-0ec72d15c425" |
|
|
63 |
}, |
|
|
64 |
"execution_count": 10, |
|
|
65 |
"outputs": [ |
|
|
66 |
{ |
|
|
67 |
"output_type": "stream", |
|
|
68 |
"name": "stdout", |
|
|
69 |
"text": [ |
|
|
70 |
"Time in seconds since beginning of run: 1712858775.0148897\n", |
|
|
71 |
"Thu Apr 11 18:06:15 2024\n" |
|
|
72 |
] |
|
|
73 |
} |
|
|
74 |
] |
|
|
75 |
}, |
|
|
76 |
{ |
|
|
77 |
"cell_type": "code", |
|
|
78 |
"source": [ |
|
|
79 |
"# Load tokenizer and model\n", |
|
|
80 |
"tokenizer = BertTokenizer.from_pretrained(\"Rostlab/prot_bert\", do_lower_case=False)\n", |
|
|
81 |
"model = BertForMaskedLM.from_pretrained(\"Rostlab/prot_bert\")\n", |
|
|
82 |
"unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)\n", |
|
|
83 |
"\n", |
|
|
84 |
"# Define the sequence with two masked tokens\n", |
|
|
85 |
"sequence = 'FYLSITIHRPLRP[MASK]SSSSFLSLCLSLLSISIYYPS\\nLLIRRFTSISSCSSITIYHPLLYPSPSSLFLSLSHTYIYISPLHPSSLLLSISLLFYLSI\\nYIIYPLQPSSLLLSI[MASK]SLPLSISIYLSYPPLSSPSPSLSLYLTPFLLIPSLSIYLSLPFPY\\nHSYLYLRLLFHPPLPLHICHLPHSLTLFIFLLPPHLSHLPILFSRLQPFYPSTSPSSYRP\\nLPCIPSASYFSYHPLSPPPSLHPHPLSYPSVSRPSPPYLSIHLHSPPPPPPPSPFSSIHP\\nPFLSSTLPLPSSTSSLPPSSSPFSSTHLIPSPSSPPPPSLLPSSLPL'\n", |
|
|
86 |
"\n", |
|
|
87 |
"# Predict and print top 10 predictions for each mask\n", |
|
|
88 |
"result = unmasker(sequence, top_k=10)\n", |
|
|
89 |
"for predictions in result:\n", |
|
|
90 |
" for prediction in predictions:\n", |
|
|
91 |
" print(prediction)\n", |
|
|
92 |
" print(\"------\")" |
|
|
93 |
], |
|
|
94 |
"metadata": { |
|
|
95 |
"id": "YX05PGDZNImk", |
|
|
96 |
"colab": { |
|
|
97 |
"base_uri": "https://localhost:8080/", |
|
|
98 |
"height": 0 |
|
|
99 |
}, |
|
|
100 |
"outputId": "25f2ab86-fc66-4279-cf8f-a1621d267e6f" |
|
|
101 |
}, |
|
|
102 |
"execution_count": 11, |
|
|
103 |
"outputs": [ |
|
|
104 |
{ |
|
|
105 |
"output_type": "stream", |
|
|
106 |
"name": "stderr", |
|
|
107 |
"text": [ |
|
|
108 |
"Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", |
|
|
109 |
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", |
|
|
110 |
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" |
|
|
111 |
] |
|
|
112 |
}, |
|
|
113 |
{ |
|
|
114 |
"output_type": "stream", |
|
|
115 |
"name": "stdout", |
|
|
116 |
"text": [ |
|
|
117 |
"{'score': 0.9999492168426514, 'token': 25, 'token_str': 'X', 'sequence': '[CLS] [UNK] X [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
118 |
"{'score': 7.70106453273911e-06, 'token': 13, 'token_str': 'R', 'sequence': '[CLS] [UNK] R [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
119 |
"{'score': 6.548679721163353e-06, 'token': 5, 'token_str': 'L', 'sequence': '[CLS] [UNK] L [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
120 |
"{'score': 4.85780628878274e-06, 'token': 19, 'token_str': 'F', 'sequence': '[CLS] [UNK] F [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
121 |
"{'score': 4.367872861621436e-06, 'token': 18, 'token_str': 'Q', 'sequence': '[CLS] [UNK] Q [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
122 |
"{'score': 4.232546871207887e-06, 'token': 7, 'token_str': 'G', 'sequence': '[CLS] [UNK] G [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
123 |
"{'score': 3.1130280149227474e-06, 'token': 12, 'token_str': 'K', 'sequence': '[CLS] [UNK] K [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
124 |
"{'score': 2.9373295546974987e-06, 'token': 16, 'token_str': 'P', 'sequence': '[CLS] [UNK] P [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
125 |
"{'score': 2.8243982796993805e-06, 'token': 6, 'token_str': 'A', 'sequence': '[CLS] [UNK] A [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
126 |
"{'score': 2.3990457975742174e-06, 'token': 15, 'token_str': 'T', 'sequence': '[CLS] [UNK] T [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
127 |
"------\n", |
|
|
128 |
"{'score': 0.9999690055847168, 'token': 25, 'token_str': 'X', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] X [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
129 |
"{'score': 4.743668341689045e-06, 'token': 5, 'token_str': 'L', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] L [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
130 |
"{'score': 4.336807705840329e-06, 'token': 13, 'token_str': 'R', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] R [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
131 |
"{'score': 4.116420768696116e-06, 'token': 18, 'token_str': 'Q', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] Q [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
132 |
"{'score': 3.6834383081441047e-06, 'token': 6, 'token_str': 'A', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] A [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
133 |
"{'score': 2.364193051107577e-06, 'token': 7, 'token_str': 'G', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] G [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
134 |
"{'score': 2.1007083432778018e-06, 'token': 16, 'token_str': 'P', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] P [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
135 |
"{'score': 1.6038892454162124e-06, 'token': 15, 'token_str': 'T', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] T [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
136 |
"{'score': 1.3503024547389941e-06, 'token': 10, 'token_str': 'S', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] S [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
137 |
"{'score': 1.272862732548674e-06, 'token': 12, 'token_str': 'K', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] K [UNK] [UNK] [UNK] [UNK] [SEP]'}\n", |
|
|
138 |
"------\n" |
|
|
139 |
] |
|
|
140 |
} |
|
|
141 |
] |
|
|
142 |
}, |
|
|
143 |
{ |
|
|
144 |
"cell_type": "code", |
|
|
145 |
"source": [ |
|
|
146 |
"# Load tokenizer and model\n", |
|
|
147 |
"tokenizer = BertTokenizer.from_pretrained(\"Rostlab/prot_bert\", do_lower_case=False)\n", |
|
|
148 |
"model = BertForMaskedLM.from_pretrained(\"Rostlab/prot_bert\")\n", |
|
|
149 |
"unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)\n", |
|
|
150 |
"\n", |
|
|
151 |
"# Define the sequence with two masked tokens\n", |
|
|
152 |
"sequence = 'FYLSITIHRPLRP[MASK]SSSSFLSLCLSLLSISIYYPS\\nLLIRRFTSISSCSSITIYHPLLYPSPSSLFLSLSHTYIYISPLHPSSLLLSISLLFYLSI\\nYIIYPLQPSSLLLSI[MASK]SLPLSISIYLSYPPLSSPSPSLSLYLTPFLLIPSLSIYLSLPFPY\\nHSYLYLRLLFHPPLPLHICHLPHSLTLFIFLLPPHLSHLPILFSRLQPFYPSTSPSSYRP\\nLPCIPSASYFSYHPLSPPPSLHPHPLSYPSVSRPSPPYLSIHLHSPPPPPPPSPFSSIHP\\nPFLSSTLPLPSSTSSLPPSSSPFSSTHLIPSPSSPPPPSLLP[MASK]SSLPL'\n", |
|
|
153 |
"\n", |
|
|
154 |
"# Predict and print top 10 predictions for each mask\n", |
|
|
155 |
"result = unmasker(sequence, top_k=10)\n", |
|
|
156 |
"for predictions in result:\n", |
|
|
157 |
" for prediction in predictions:\n", |
|
|
158 |
" print(prediction)\n", |
|
|
159 |
" print(\"------\")" |
|
|
160 |
], |
|
|
161 |
"metadata": { |
|
|
162 |
"colab": { |
|
|
163 |
"base_uri": "https://localhost:8080/", |
|
|
164 |
"height": 0 |
|
|
165 |
}, |
|
|
166 |
"id": "McDmgTTdNlA9", |
|
|
167 |
"outputId": "551fd198-7b8b-4601-cf93-40dc5da22f95" |
|
|
168 |
}, |
|
|
169 |
"execution_count": 12, |
|
|
170 |
"outputs": [ |
|
|
171 |
{ |
|
|
172 |
"output_type": "stream", |
|
|
173 |
"name": "stderr", |
|
|
174 |
"text": [ |
|
|
175 |
"Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", |
|
|
176 |
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", |
|
|
177 |
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" |
|
|
178 |
] |
|
|
179 |
}, |
|
|
180 |
{ |
|
|
181 |
"output_type": "stream", |
|
|
182 |
"name": "stdout", |
|
|
183 |
"text": [ |
|
|
184 |
"{'score': 0.9999243021011353, 'token': 25, 'token_str': 'X', 'sequence': '[CLS] [UNK] X [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
185 |
"{'score': 9.95707159745507e-06, 'token': 5, 'token_str': 'L', 'sequence': '[CLS] [UNK] L [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
186 |
"{'score': 7.378088412224315e-06, 'token': 7, 'token_str': 'G', 'sequence': '[CLS] [UNK] G [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
187 |
"{'score': 7.330473636102397e-06, 'token': 9, 'token_str': 'E', 'sequence': '[CLS] [UNK] E [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
188 |
"{'score': 6.249400939850602e-06, 'token': 12, 'token_str': 'K', 'sequence': '[CLS] [UNK] K [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
189 |
"{'score': 6.2195363170758355e-06, 'token': 19, 'token_str': 'F', 'sequence': '[CLS] [UNK] F [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
190 |
"{'score': 6.003328053338919e-06, 'token': 6, 'token_str': 'A', 'sequence': '[CLS] [UNK] A [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
191 |
"{'score': 5.373394287744304e-06, 'token': 13, 'token_str': 'R', 'sequence': '[CLS] [UNK] R [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
192 |
"{'score': 4.275994342606282e-06, 'token': 20, 'token_str': 'Y', 'sequence': '[CLS] [UNK] Y [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
193 |
"{'score': 3.627200385381002e-06, 'token': 8, 'token_str': 'V', 'sequence': '[CLS] [UNK] V [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
194 |
"------\n", |
|
|
195 |
"{'score': 0.9996041655540466, 'token': 25, 'token_str': 'X', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] X [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
196 |
"{'score': 5.5235712352441624e-05, 'token': 6, 'token_str': 'A', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] A [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
197 |
"{'score': 5.422084723250009e-05, 'token': 5, 'token_str': 'L', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] L [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
198 |
"{'score': 4.226986857247539e-05, 'token': 7, 'token_str': 'G', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] G [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
199 |
"{'score': 3.9134716644184664e-05, 'token': 13, 'token_str': 'R', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] R [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
200 |
"{'score': 3.062105315621011e-05, 'token': 18, 'token_str': 'Q', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] Q [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
201 |
"{'score': 2.8334136004559696e-05, 'token': 9, 'token_str': 'E', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] E [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
202 |
"{'score': 2.7488824343890883e-05, 'token': 12, 'token_str': 'K', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] K [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
203 |
"{'score': 1.9406317733228207e-05, 'token': 16, 'token_str': 'P', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] P [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
204 |
"{'score': 1.778588557499461e-05, 'token': 19, 'token_str': 'F', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] F [UNK] [UNK] [UNK] [UNK] [MASK] [UNK] [SEP]'}\n", |
|
|
205 |
"------\n", |
|
|
206 |
"{'score': 0.9990819692611694, 'token': 25, 'token_str': 'X', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] X [UNK] [SEP]'}\n", |
|
|
207 |
"{'score': 0.00016688588948454708, 'token': 6, 'token_str': 'A', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] A [UNK] [SEP]'}\n", |
|
|
208 |
"{'score': 8.889797754818574e-05, 'token': 5, 'token_str': 'L', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] L [UNK] [SEP]'}\n", |
|
|
209 |
"{'score': 8.658925071358681e-05, 'token': 12, 'token_str': 'K', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] K [UNK] [SEP]'}\n", |
|
|
210 |
"{'score': 6.849779310869053e-05, 'token': 18, 'token_str': 'Q', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] Q [UNK] [SEP]'}\n", |
|
|
211 |
"{'score': 6.579031469300389e-05, 'token': 20, 'token_str': 'Y', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] Y [UNK] [SEP]'}\n", |
|
|
212 |
"{'score': 6.401719292625785e-05, 'token': 7, 'token_str': 'G', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] G [UNK] [SEP]'}\n", |
|
|
213 |
"{'score': 5.076546221971512e-05, 'token': 9, 'token_str': 'E', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] E [UNK] [SEP]'}\n", |
|
|
214 |
"{'score': 4.9771988415159285e-05, 'token': 13, 'token_str': 'R', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] R [UNK] [SEP]'}\n", |
|
|
215 |
"{'score': 4.564020491670817e-05, 'token': 19, 'token_str': 'F', 'sequence': '[CLS] [UNK] [MASK] [UNK] [UNK] [UNK] [MASK] [UNK] [UNK] [UNK] [UNK] F [UNK] [SEP]'}\n", |
|
|
216 |
"------\n" |
|
|
217 |
] |
|
|
218 |
} |
|
|
219 |
] |
|
|
220 |
}, |
|
|
221 |
{ |
|
|
222 |
"cell_type": "code", |
|
|
223 |
"source": [ |
|
|
224 |
"# Load tokenizer and model\n", |
|
|
225 |
"tokenizer = BertTokenizer.from_pretrained(\"Rostlab/prot_bert\", do_lower_case=False)\n", |
|
|
226 |
"model = BertModel.from_pretrained(\"Rostlab/prot_bert\")\n", |
|
|
227 |
"\n", |
|
|
228 |
"# Example protein sequence\n", |
|
|
229 |
"sequence_Example = \"FYLSITIHRPLRPSSSSFLSLCLSLLSISIYYPS\\nLLIRRFTSISSCSSITIYHPLLYPSPSSLFLSLSHTYIYISPLHPSSLLLSISLLFYLSI\\nYIIYPLQPSSLLLSISLPLSISIYLSYPPLSSPSPSLSLYLTPFLLIPSLSIYLSLPFPY\\nHSYLYLRLLFHPPLPLHICHLPHSLTLFIFLLPPHLSHLPILFSRLQPFYPSTSPSSYRP\\nLPCIPSASYFSYHPLSPPPSLHPHPLSYPSVSRPSPPYLSIHLHSPPPPPPPSPFSSIHP\\nPFLSSTLPLPSSTSSLPPSSSPFSSTHLIPSPSSPPPPSLLPSSLPL\"\n", |
|
|
230 |
"sequence_Example = re.sub(r\"[UZOB]\", \"X\", sequence_Example)\n", |
|
|
231 |
"\n", |
|
|
232 |
"# Tokenize input sequence\n", |
|
|
233 |
"encoded_input = tokenizer(sequence_Example, return_tensors='pt')\n", |
|
|
234 |
"\n", |
|
|
235 |
"# Forward pass through the model\n", |
|
|
236 |
"output = model(**encoded_input)\n", |
|
|
237 |
"\n", |
|
|
238 |
"print(output)\n", |
|
|
239 |
"\n", |
|
|
240 |
"# Extract last hidden states\n", |
|
|
241 |
"last_hidden_state = output.last_hidden_state\n", |
|
|
242 |
"\n", |
|
|
243 |
"# Calculate additional metrics\n", |
|
|
244 |
"mean_tensor = torch.mean(last_hidden_state, dim=1) # Mean along sequence length\n", |
|
|
245 |
"std_tensor = torch.std(last_hidden_state, dim=1) # Standard deviation along sequence length\n", |
|
|
246 |
"max_tensor = torch.max(last_hidden_state, dim=1).values # Maximum value along sequence length\n", |
|
|
247 |
"\n", |
|
|
248 |
"# Print additional metrics\n", |
|
|
249 |
"print(\"Mean tensor shape:\", mean_tensor.shape)\n", |
|
|
250 |
"print(\"Standard deviation tensor shape:\", std_tensor.shape)\n", |
|
|
251 |
"print(\"Maximum tensor shape:\", max_tensor.shape)" |
|
|
252 |
], |
|
|
253 |
"metadata": { |
|
|
254 |
"id": "tiXjSEo2MfrD", |
|
|
255 |
"colab": { |
|
|
256 |
"base_uri": "https://localhost:8080/", |
|
|
257 |
"height": 0 |
|
|
258 |
}, |
|
|
259 |
"outputId": "55f75b62-24fc-4524-c3fe-04c076842cdd" |
|
|
260 |
}, |
|
|
261 |
"execution_count": 13, |
|
|
262 |
"outputs": [ |
|
|
263 |
{ |
|
|
264 |
"output_type": "stream", |
|
|
265 |
"name": "stdout", |
|
|
266 |
"text": [ |
|
|
267 |
"BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.1217, 0.0759, -0.2128, ..., 0.0417, -0.0052, 0.1067],\n", |
|
|
268 |
" [-0.0694, 0.1312, -0.1388, ..., -0.0231, 0.0595, 0.0174],\n", |
|
|
269 |
" [-0.1399, 0.0801, -0.1725, ..., 0.0329, -0.0341, 0.0568],\n", |
|
|
270 |
" ...,\n", |
|
|
271 |
" [-0.1345, 0.0499, -0.1611, ..., 0.0035, -0.0035, 0.0531],\n", |
|
|
272 |
" [-0.1569, 0.0764, -0.1574, ..., -0.0122, -0.0043, 0.0533],\n", |
|
|
273 |
" [-0.1272, 0.0725, -0.1479, ..., -0.0294, 0.0445, -0.0033]]],\n", |
|
|
274 |
" grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.1908, 0.2020, -0.1812, ..., 0.1921, 0.1790, -0.1964]],\n", |
|
|
275 |
" grad_fn=<TanhBackward0>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)\n", |
|
|
276 |
"Mean tensor shape: torch.Size([1, 1024])\n", |
|
|
277 |
"Standard deviation tensor shape: torch.Size([1, 1024])\n", |
|
|
278 |
"Maximum tensor shape: torch.Size([1, 1024])\n" |
|
|
279 |
] |
|
|
280 |
} |
|
|
281 |
] |
|
|
282 |
}, |
|
|
283 |
{ |
|
|
284 |
"cell_type": "code", |
|
|
285 |
"source": [ |
|
|
286 |
"seconds = time.time()\n", |
|
|
287 |
"print(\"Time in seconds since end of run:\", seconds)\n", |
|
|
288 |
"local_time = time.ctime(seconds)\n", |
|
|
289 |
"print(local_time)" |
|
|
290 |
], |
|
|
291 |
"metadata": { |
|
|
292 |
"colab": { |
|
|
293 |
"base_uri": "https://localhost:8080/", |
|
|
294 |
"height": 0 |
|
|
295 |
}, |
|
|
296 |
"id": "TkiND_HHuERP", |
|
|
297 |
"outputId": "c8fbb5be-e846-42ef-d6bc-6347d9a3ba87" |
|
|
298 |
}, |
|
|
299 |
"execution_count": 14, |
|
|
300 |
"outputs": [ |
|
|
301 |
{ |
|
|
302 |
"output_type": "stream", |
|
|
303 |
"name": "stdout", |
|
|
304 |
"text": [ |
|
|
305 |
"Time in seconds since end of run: 1712858783.8144772\n", |
|
|
306 |
"Thu Apr 11 18:06:23 2024\n" |
|
|
307 |
] |
|
|
308 |
} |
|
|
309 |
] |
|
|
310 |
} |
|
|
311 |
] |
|
|
312 |
} |