{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "8ae13342-5979-42df-9fea-5bb97c60be23",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import time\n",
"from pathlib import Path\n",
"from typing import Any, Literal, Optional\n",
"import json\n",
"\n",
"import lightning as L\n",
"import torch\n",
"import torch._dynamo.config\n",
"import torch._inductor.config\n",
"from lightning.fabric.plugins import BitsandbytesPrecision\n",
"from lightning.fabric.strategies import FSDPStrategy\n",
"\n",
"## Add the lit_gpt folder to the path\n",
"sys.path.insert(0, os.path.abspath('../'))\n",
"\n",
"from lit_gpt import GPT, Config, Tokenizer\n",
"from lit_gpt.model import Block\n",
"from lit_gpt.utils import (\n",
" check_valid_checkpoint_dir,\n",
" get_default_supported_precision,\n",
" gptq_quantization,\n",
" load_checkpoint,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ac87e9a-6846-4f84-9c02-27f46fa417ce",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"torch.set_float32_matmul_precision(\"high\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6eec0b2-81d9-4a88-b20c-0be57ec6b2a2",
"metadata": {},
"outputs": [],
"source": [
"with open('../data/entity_extraction/entity-extraction-test-data.json', 'r') as file:\n",
" test_data = json.load(file)"
]
},
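{
"cell_type": "code",
"execution_count": null,
"id": "c3f6e8f4-1d90-4b3c-ae5f-4a7d8c9e0f12",
"metadata": {},
"outputs": [],
"source": [
"## Quick sanity check (a sketch, not part of the original run; assumes the file\n",
"## holds a list of records with \"input\"/\"output\" fields like `example` below)\n",
"print(f\"Loaded {len(test_data)} test records\")"
]
},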
{
"cell_type": "code",
"execution_count": null,
"id": "4d92ad07",
"metadata": {},
"outputs": [],
"source": [
"example = {\n",
" \"input\": \"Natalie Cooper,\\nncooper@example.com\\n6789 Birch Street, Denver, CO 80203,\\n303-555-6543, United States\\n\\nRelationship to XYZ Pharma Inc.: Patient\\nReason for contacting: Adverse Event\\n\\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper\",\n",
" \"output\": \"{\\\"drug_name\\\": \\\"Abilify\\\", \\\"adverse_events\\\": [\\\"nausea\\\", \\\"vomiting\\\"]}\"\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d5eed73-7020-466a-8b83-5ebc7d497698",
"metadata": {},
"outputs": [],
"source": [
"prompt = f\"\"\"Act as an expert Analyst with 20+ years of experience\\\n",
"in Pharma and Healthcare industry. \\\n",
"For the following provided input you need to generate the output which \\\n",
"identifies and extracts entities like 'drug_name' and 'adverse_events' \\\n",
"use the format:\\n\\\n",
"{{'drug_name':'DRUG_NAME_HERE', 'adverse_events':[## List of symptoms here]}}\\n\\\n",
"\n",
"### Extract Entities from the follwing:\\n\\\n",
"{example[\"input\"]}\\\n",
"\n",
"### Response:\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c972c1d6-f2f3-4bf8-bd92-405bda4b02ab",
"metadata": {},
"outputs": [],
"source": [
"print(prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ae2d36b-3f58-4be1-955d-cca8c8c5f7d6",
"metadata": {},
"outputs": [],
"source": [
"def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:\n",
" if torch._dynamo.is_compiling():\n",
" # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly\n",
" distribution = torch.empty_like(probs).exponential_(1)\n",
" return torch.argmax(probs / distribution, dim=-1, keepdim=True)\n",
" return torch.multinomial(probs, num_samples=1)\n",
"\n",
"\n",
"def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:\n",
" logits = logits[0, -1]\n",
" # optionally crop the logits to only the top k options\n",
" if top_k is not None:\n",
" v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n",
" # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n",
" logits = torch.full_like(logits, float(\"-inf\")).scatter_(-1, i, v)\n",
" # optionally scale the logits and sample from a probability distribution\n",
" if temperature > 0.0:\n",
" probs = torch.nn.functional.softmax(logits / temperature, dim=-1)\n",
" return multinomial_num_samples_1(probs)\n",
" return torch.argmax(logits, dim=-1, keepdim=True)\n",
"\n",
"def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:\n",
" logits = model(x, input_pos)\n",
" next = sample(logits, **kwargs)\n",
" return next.type_as(x)\n",
"\n",
"@torch.inference_mode()\n",
"def generate(\n",
" model: GPT,\n",
" prompt: torch.Tensor,\n",
" max_returned_tokens: int,\n",
" *,\n",
" temperature: float = 1.0,\n",
" top_k: Optional[int] = None,\n",
" eos_id: Optional[int] = None,\n",
") -> torch.Tensor:\n",
" \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n",
"\n",
" The implementation of this function is modified from A. Karpathy's nanoGPT.\n",
"\n",
" Args:\n",
" model: The model to use.\n",
" prompt: Tensor of shape (T) with indices of the prompt sequence.\n",
" max_returned_tokens: The maximum number of tokens to return (given plus generated).\n",
" temperature: Scales the predicted logits by 1 / temperature.\n",
" top_k: If specified, only sample among the tokens with the k highest probabilities.\n",
" eos_id: If specified, stop generating any more token once the <eos> token is triggered.\n",
" \"\"\"\n",
" T = prompt.size(0)\n",
" assert max_returned_tokens > T\n",
" if model.max_seq_length < max_returned_tokens - 1:\n",
" # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a\n",
" # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do\n",
" # not support it to avoid negatively impacting the overall speed\n",
" raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n",
"\n",
" device = prompt.device\n",
" tokens = [prompt]\n",
" input_pos = torch.tensor([T], device=device)\n",
" token = next_token(\n",
" model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k\n",
" ).clone()\n",
" tokens.append(token)\n",
" for _ in range(2, max_returned_tokens - T + 1):\n",
" token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()\n",
" tokens.append(token)\n",
" if token == eos_id:\n",
" break\n",
" input_pos = input_pos.add_(1)\n",
" return torch.cat(tokens)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3c30b09-e631-44b8-a45b-ac79dbbab52f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"print(\"[INFO] Using StableLM-3B base model\")\n",
"checkpoint_dir: Path = Path(\"../checkpoints/stabilityai/stablelm-base-alpha-3b\")\n",
"predictions_file_name = '../data/predictions-stablelm-base.json'\n",
"\n",
"quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\", \"gptq.int4\"]] = None\n",
"max_new_tokens: int = 50\n",
"top_k: int = 200\n",
"temperature: float = 0.1\n",
"strategy: str = \"auto\"\n",
"devices: int = 1\n",
"precision: Optional[str] = None\n",
"num_samples: int = 1,"
]
},
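{
"cell_type": "markdown",
"id": "a1f4c6d2-9b7e-4f1a-8c3d-2e5b6a7c8d90",
"metadata": {},
"source": [
"The next cell is an optional sketch (following the pattern in lit-gpt's `generate` script, not part of the original run) showing how the `bnb.*` quantization modes above could be wired to the `BitsandbytesPrecision` plugin imported at the top. If used, the resulting `plugins` object would also need to be passed as `L.Fabric(..., plugins=plugins)` in the cell after it; with `quantize=None` it is a no-op."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e5d7e3-0c8f-4a2b-9d4e-3f6c7b8d9e01",
"metadata": {},
"outputs": [],
"source": [
"## Optional sketch: map a `bnb.*` quantize mode to the BitsandbytesPrecision plugin.\n",
"## Assumes a non-mixed precision; pass `plugins=plugins` to `L.Fabric` below to use it.\n",
"plugins = None\n",
"if quantize is not None and quantize.startswith(\"bnb.\"):\n",
"    dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[\n",
"        precision or get_default_supported_precision(training=False)\n",
"    ]\n",
"    plugins = BitsandbytesPrecision(quantize[4:], dtype)\n",
"    precision = None"
]
},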
{
"cell_type": "code",
"execution_count": null,
"id": "b842b257-2c9c-4c64-b011-f97e215c3794",
"metadata": {},
"outputs": [],
"source": [
"if strategy == \"fsdp\":\n",
" strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)\n",
"fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)\n",
"fabric.launch()\n",
"\n",
"check_valid_checkpoint_dir(checkpoint_dir)\n",
"\n",
"config = Config.from_json(checkpoint_dir / \"lit_config.json\")\n",
"\n",
"if quantize is not None and devices > 1:\n",
" raise NotImplementedError\n",
"if quantize == \"gptq.int4\":\n",
" model_file = \"lit_model_gptq.4bit.pth\"\n",
" if not (checkpoint_dir / model_file).is_file():\n",
" raise ValueError(\"Please run `python quantize/gptq.py` first\")\n",
"else:\n",
" model_file = \"lit_model.pth\"\n",
" \n",
"checkpoint_path = checkpoint_dir / model_file\n",
"tokenizer = Tokenizer(checkpoint_dir)\n",
"encoded = tokenizer.encode(prompt, device=fabric.device)\n",
"prompt_length = encoded.size(0)\n",
"max_returned_tokens = prompt_length + max_new_tokens\n",
"\n",
"fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n",
"t0 = time.perf_counter()\n",
"with fabric.init_module(empty_init=True), gptq_quantization(quantize == \"gptq.int4\"):\n",
" model = GPT(config)\n",
"fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n",
"with fabric.init_tensor():\n",
" # set the max_seq_length to limit the memory usage to what we need\n",
" model.max_seq_length = max_returned_tokens\n",
" # enable the kv cache\n",
" model.set_kv_cache(batch_size=1)\n",
"model.eval()\n",
"\n",
"model = fabric.setup_module(model)\n",
"t0 = time.perf_counter()\n",
"load_checkpoint(fabric, model, checkpoint_path)\n",
"fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59e00f69-d9fd-4cf5-a4ed-f142991b870f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3eeaa237-7a95-4c9d-aeda-8a3c246d02fe",
"metadata": {},
"outputs": [],
"source": [
"L.seed_everything(1234)\n",
"\n",
"t0 = time.perf_counter()\n",
"y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)\n",
"t = time.perf_counter() - t0\n",
"for block in model.transformer.h:\n",
" block.attn.kv_cache.reset_parameters()\n",
"output = tokenizer.decode(y)\n",
"fabric.print(output)\n",
"tokens_generated = y.size(0) - prompt_length"
]
}
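,
{
"cell_type": "markdown",
"id": "d4a7f9a5-2ea1-4c4d-bf60-5b8e9d0f1a23",
"metadata": {},
"source": [
"The final cell is a minimal sketch (not executed as part of the original run) of how the same prompt template could be applied to every record in `test_data` and the raw generations written to `predictions_file_name`. It assumes each record has an `input` field shaped like `example` above, and that every prompt plus `max_new_tokens` fits within the `max_seq_length` / kv-cache size configured earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5b8a0b6-3fb2-4d5e-9071-6c9f0e1a2b34",
"metadata": {},
"outputs": [],
"source": [
"## Sketch: batch inference over the test set and save raw predictions.\n",
"## Assumes each record has an \"input\" field and that prompts fit in the kv cache\n",
"## sized above; rebuilds the same instruction prompt used for `example`.\n",
"def build_prompt(record):\n",
"    return (\n",
"        \"Act as an expert Analyst with 20+ years of experience \"\n",
"        \"in Pharma and Healthcare industry. \"\n",
"        \"For the following provided input you need to generate the output which \"\n",
"        \"identifies and extracts entities like 'drug_name' and 'adverse_events' \"\n",
"        \"use the format:\\n\"\n",
"        \"{'drug_name':'DRUG_NAME_HERE', 'adverse_events':[## List of symptoms here]}\\n\"\n",
"        \"\\n\"\n",
"        f\"### Extract Entities from the following:\\n{record['input']}\\n\"\n",
"        \"### Response:\\n\"\n",
"    )\n",
"\n",
"predictions = []\n",
"for record in test_data:\n",
"    encoded_record = tokenizer.encode(build_prompt(record), device=fabric.device)\n",
"    y = generate(\n",
"        model,\n",
"        encoded_record,\n",
"        encoded_record.size(0) + max_new_tokens,\n",
"        temperature=temperature,\n",
"        top_k=top_k,\n",
"    )\n",
"    # reset the kv cache between samples, as done for the single example above\n",
"    for block in model.transformer.h:\n",
"        block.attn.kv_cache.reset_parameters()\n",
"    predictions.append({\"input\": record[\"input\"], \"prediction\": tokenizer.decode(y)})\n",
"\n",
"with open(predictions_file_name, 'w') as file:\n",
"    json.dump(predictions, file, indent=2)"
]
}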
],
"metadata": {
"kernelspec": {
"display_name": "new",
"language": "python",
"name": "new"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}