notebooks/inference_base.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae13342-5979-42df-9fea-5bb97c60be23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "from typing import Any, Literal, Optional\n",
    "import json\n",
    "\n",
    "import lightning as L\n",
    "import torch\n",
    "import torch._dynamo.config\n",
    "import torch._inductor.config\n",
    "from lightning.fabric.plugins import BitsandbytesPrecision\n",
    "from lightning.fabric.strategies import FSDPStrategy\n",
    "\n",
    "## Add the lit_gpt folder to the path\n",
    "sys.path.insert(0, os.path.abspath('../'))\n",
    "\n",
    "from lit_gpt import GPT, Config, Tokenizer\n",
    "from lit_gpt.model import Block\n",
    "from lit_gpt.utils import (\n",
    "    check_valid_checkpoint_dir,\n",
    "    get_default_supported_precision,\n",
    "    gptq_quantization,\n",
    "    load_checkpoint,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac87e9a-6846-4f84-9c02-27f46fa417ce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.set_float32_matmul_precision(\"high\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6eec0b2-81d9-4a88-b20c-0be57ec6b2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/entity_extraction/entity-extraction-test-data.json', 'r') as file:\n",
    "    test_data = json.load(file)"
   ]
  },
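  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee11-aaaa-4bbb-8ccc-1234567890ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative peek at the loaded test data; assumes a list of records with\n",
    "# \"input\"/\"output\" keys, matching the example constructed in the next cell\n",
    "print(type(test_data), len(test_data))"
   ]
  },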
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d92ad07",
   "metadata": {},
   "outputs": [],
   "source": [
    "example = {\n",
    "        \"input\": \"Natalie Cooper,\\nncooper@example.com\\n6789 Birch Street, Denver, CO 80203,\\n303-555-6543, United States\\n\\nRelationship to XYZ Pharma Inc.: Patient\\nReason for contacting: Adverse Event\\n\\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper\",\n",
    "        \"output\": \"{\\\"drug_name\\\": \\\"Abilify\\\", \\\"adverse_events\\\": [\\\"nausea\\\", \\\"vomiting\\\"]}\"\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5eed73-7020-466a-8b83-5ebc7d497698",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = f\"\"\"Act as an expert Analyst with 20+ years of experience\\\n",
81
    "in Pharma and Healthcare industry. \\\n",
82
    "For the following provided input you need to generate the output which \\\n",
83
    "identifies and extracts entities like 'drug_name' and 'adverse_events' \\\n",
84
    "use the format:\\n\\\n",
85
    "{{'drug_name':'DRUG_NAME_HERE', 'adverse_events':[## List of symptoms here]}}\\n\\\n",
86
    "\n",
87
    "### Extract Entities from the follwing:\\n\\\n",
88
    "{example[\"input\"]}\\\n",
89
    "\n",
90
    "### Response:\n",
91
    "\"\"\""
92
   ]
93
  },
94
  {
95
   "cell_type": "code",
96
   "execution_count": null,
97
   "id": "c972c1d6-f2f3-4bf8-bd92-405bda4b02ab",
98
   "metadata": {},
99
   "outputs": [],
100
   "source": [
101
    "print(prompt)"
102
   ]
103
  },
104
  {
105
   "cell_type": "code",
106
   "execution_count": null,
107
   "id": "3ae2d36b-3f58-4be1-955d-cca8c8c5f7d6",
108
   "metadata": {},
109
   "outputs": [],
110
   "source": [
111
    "def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:\n",
112
    "    if torch._dynamo.is_compiling():\n",
113
    "        # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly\n",
    "        distribution = torch.empty_like(probs).exponential_(1)\n",
    "        return torch.argmax(probs / distribution, dim=-1, keepdim=True)\n",
    "    return torch.multinomial(probs, num_samples=1)\n",
    "\n",
    "\n",
    "def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:\n",
    "    logits = logits[0, -1]\n",
    "    # optionally crop the logits to only the top k options\n",
    "    if top_k is not None:\n",
    "        v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n",
    "        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n",
    "        logits = torch.full_like(logits, float(\"-inf\")).scatter_(-1, i, v)\n",
    "    # optionally scale the logits and sample from a probability distribution\n",
    "    if temperature > 0.0:\n",
    "        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)\n",
    "        return multinomial_num_samples_1(probs)\n",
    "    return torch.argmax(logits, dim=-1, keepdim=True)\n",
    "\n",
    "def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:\n",
    "    logits = model(x, input_pos)\n",
    "    next = sample(logits, **kwargs)\n",
135
    "    return next.type_as(x)\n",
136
    "\n",
137
    "@torch.inference_mode()\n",
138
    "def generate(\n",
139
    "    model: GPT,\n",
140
    "    prompt: torch.Tensor,\n",
141
    "    max_returned_tokens: int,\n",
142
    "    *,\n",
143
    "    temperature: float = 1.0,\n",
144
    "    top_k: Optional[int] = None,\n",
145
    "    eos_id: Optional[int] = None,\n",
146
    ") -> torch.Tensor:\n",
147
    "    \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n",
148
    "\n",
149
    "    The implementation of this function is modified from A. Karpathy's nanoGPT.\n",
150
    "\n",
151
    "    Args:\n",
152
    "        model: The model to use.\n",
153
    "        prompt: Tensor of shape (T) with indices of the prompt sequence.\n",
154
    "        max_returned_tokens: The maximum number of tokens to return (given plus generated).\n",
155
    "        temperature: Scales the predicted logits by 1 / temperature.\n",
156
    "        top_k: If specified, only sample among the tokens with the k highest probabilities.\n",
157
    "        eos_id: If specified, stop generating any more token once the <eos> token is triggered.\n",
158
    "    \"\"\"\n",
159
    "    T = prompt.size(0)\n",
160
    "    assert max_returned_tokens > T\n",
161
    "    if model.max_seq_length < max_returned_tokens - 1:\n",
162
    "        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a\n",
163
    "        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do\n",
164
    "        # not support it to avoid negatively impacting the overall speed\n",
165
    "        raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n",
166
    "\n",
167
    "    device = prompt.device\n",
168
    "    tokens = [prompt]\n",
169
    "    input_pos = torch.tensor([T], device=device)\n",
170
    "    token = next_token(\n",
171
    "        model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k\n",
172
    "    ).clone()\n",
173
    "    tokens.append(token)\n",
174
    "    for _ in range(2, max_returned_tokens - T + 1):\n",
175
    "        token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()\n",
176
    "        tokens.append(token)\n",
177
    "        if token == eos_id:\n",
178
    "            break\n",
179
    "        input_pos = input_pos.add_(1)\n",
180
    "    return torch.cat(tokens)"
181
   ]
182
  },
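  {
   "cell_type": "markdown",
   "id": "f0e1d2c3-1111-4a2b-9c3d-4e5f6a7b8c9d",
   "metadata": {},
   "source": [
    "A quick sanity check of `sample` on dummy logits (a minimal sketch; the tensor values below are illustrative only): with `temperature=0.0` it reduces to a greedy argmax, while `top_k` restricts sampling to the highest-scoring tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b8c7d6-2222-4f3e-8d1c-0b9a8e7f6d5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dummy logits of shape (batch=1, T=1, vocab=4); index 0 scores highest\n",
    "dummy_logits = torch.tensor([[[2.0, 0.5, -1.0, 0.0]]])\n",
    "print(sample(dummy_logits, temperature=0.0))  # greedy: tensor([0])\n",
    "print(sample(dummy_logits, temperature=1.0, top_k=2))  # samples from indices {0, 1} only"
   ]
  },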
183
  {
184
   "cell_type": "code",
185
   "execution_count": null,
186
   "id": "e3c30b09-e631-44b8-a45b-ac79dbbab52f",
187
   "metadata": {
188
    "tags": []
189
   },
190
   "outputs": [],
191
   "source": [
192
    "print(\"[INFO] Using StableLM-3B base model\")\n",
193
    "checkpoint_dir: Path = Path(\"../checkpoints/stabilityai/stablelm-base-alpha-3b\")\n",
194
    "predictions_file_name = '../data/predictions-stablelm-base.json'\n",
195
    "\n",
196
    "quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\", \"gptq.int4\"]] = None\n",
197
    "max_new_tokens: int = 50\n",
198
    "top_k: int = 200\n",
199
    "temperature: float = 0.1\n",
200
    "strategy: str = \"auto\"\n",
201
    "devices: int = 1\n",
202
    "precision: Optional[str] = None\n",
203
    "num_samples: int = 1,"
204
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b842b257-2c9c-4c64-b011-f97e215c3794",
   "metadata": {},
   "outputs": [],
   "source": [
    "if strategy == \"fsdp\":\n",
    "    strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)\n",
    "fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)\n",
    "fabric.launch()\n",
    "\n",
    "check_valid_checkpoint_dir(checkpoint_dir)\n",
    "\n",
    "config = Config.from_json(checkpoint_dir / \"lit_config.json\")\n",
    "\n",
    "if quantize is not None and devices > 1:\n",
    "    raise NotImplementedError\n",
    "if quantize == \"gptq.int4\":\n",
    "    model_file = \"lit_model_gptq.4bit.pth\"\n",
    "    if not (checkpoint_dir / model_file).is_file():\n",
    "        raise ValueError(\"Please run `python quantize/gptq.py` first\")\n",
    "else:\n",
    "    model_file = \"lit_model.pth\"\n",
    "\n",
    "checkpoint_path = checkpoint_dir / model_file\n",
    "tokenizer = Tokenizer(checkpoint_dir)\n",
    "encoded = tokenizer.encode(prompt, device=fabric.device)\n",
    "prompt_length = encoded.size(0)\n",
    "max_returned_tokens = prompt_length + max_new_tokens\n",
    "\n",
    "fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n",
    "t0 = time.perf_counter()\n",
    "with fabric.init_module(empty_init=True), gptq_quantization(quantize == \"gptq.int4\"):\n",
    "    model = GPT(config)\n",
    "fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n",
    "with fabric.init_tensor():\n",
    "    # set the max_seq_length to limit the memory usage to what we need\n",
    "    model.max_seq_length = max_returned_tokens\n",
    "    # enable the kv cache\n",
    "    model.set_kv_cache(batch_size=1)\n",
    "model.eval()\n",
    "\n",
    "model = fabric.setup_module(model)\n",
    "t0 = time.perf_counter()\n",
    "load_checkpoint(fabric, model, checkpoint_path)\n",
    "fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e00f69-d9fd-4cf5-a4ed-f142991b870f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eeaa237-7a95-4c9d-aeda-8a3c246d02fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "L.seed_everything(1234)\n",
    "\n",
    "t0 = time.perf_counter()\n",
    "y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)\n",
    "t = time.perf_counter() - t0\n",
    "for block in model.transformer.h:\n",
    "    block.attn.kv_cache.reset_parameters()\n",
    "output = tokenizer.decode(y)\n",
    "fabric.print(output)\n",
    "tokens_generated = y.size(0) - prompt_length"
284
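    "# report throughput as well; `t` and `tokens_generated` are computed above\n",
    "# (similar to the readout printed by lit-gpt's generate script)\n",
    "fabric.print(f\"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr)"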
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "new",
   "language": "python",
   "name": "new"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}