[248dc9]: / notebooks / inference_base.ipynb

Download this file

309 lines (308 with data), 11.4 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae13342-5979-42df-9fea-5bb97c60be23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import time\n",
    "from pathlib import Path\n",
    "from typing import Any, Literal, Optional\n",
    "import json\n",
    "\n",
    "import lightning as L\n",
    "import torch\n",
    "import torch._dynamo.config\n",
    "import torch._inductor.config\n",
    "from lightning.fabric.plugins import BitsandbytesPrecision\n",
    "from lightning.fabric.strategies import FSDPStrategy\n",
    "\n",
    "## Add the lit_gpt folder to the path\n",
    "sys.path.insert(0, os.path.abspath('../'))\n",
    "\n",
    "from lit_gpt import GPT, Config, Tokenizer\n",
    "from lit_gpt.model import Block\n",
    "from lit_gpt.utils import (\n",
    "    check_valid_checkpoint_dir,\n",
    "    get_default_supported_precision,\n",
    "    gptq_quantization,\n",
    "    load_checkpoint,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac87e9a-6846-4f84-9c02-27f46fa417ce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.set_float32_matmul_precision(\"high\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6eec0b2-81d9-4a88-b20c-0be57ec6b2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/entity_extraction/entity-extraction-test-data.json', 'r') as file:\n",
    "    test_data = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d92ad07",
   "metadata": {},
   "outputs": [],
   "source": [
    "example = {\n",
    "        \"input\": \"Natalie Cooper,\\nncooper@example.com\\n6789 Birch Street, Denver, CO 80203,\\n303-555-6543, United States\\n\\nRelationship to XYZ Pharma Inc.: Patient\\nReason for contacting: Adverse Event\\n\\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper\",\n",
    "        \"output\": \"{\\\"drug_name\\\": \\\"Abilify\\\", \\\"adverse_events\\\": [\\\"nausea\\\", \\\"vomiting\\\"]}\"\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5eed73-7020-466a-8b83-5ebc7d497698",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = f\"\"\"Act as an expert Analyst with 20+ years of experience\\\n",
    "in Pharma and Healthcare industry. \\\n",
    "For the following provided input you need to generate the output which \\\n",
    "identifies and extracts entities like 'drug_name' and 'adverse_events' \\\n",
    "use the format:\\n\\\n",
    "{{'drug_name':'DRUG_NAME_HERE', 'adverse_events':[## List of symptoms here]}}\\n\\\n",
    "\n",
    "### Extract Entities from the follwing:\\n\\\n",
    "{example[\"input\"]}\\\n",
    "\n",
    "### Response:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c972c1d6-f2f3-4bf8-bd92-405bda4b02ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae2d36b-3f58-4be1-955d-cca8c8c5f7d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:\n",
    "    if torch._dynamo.is_compiling():\n",
    "        # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly\n",
    "        distribution = torch.empty_like(probs).exponential_(1)\n",
    "        return torch.argmax(probs / distribution, dim=-1, keepdim=True)\n",
    "    return torch.multinomial(probs, num_samples=1)\n",
    "\n",
    "\n",
    "def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:\n",
    "    logits = logits[0, -1]\n",
    "    # optionally crop the logits to only the top k options\n",
    "    if top_k is not None:\n",
    "        v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n",
    "        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n",
    "        logits = torch.full_like(logits, float(\"-inf\")).scatter_(-1, i, v)\n",
    "    # optionally scale the logits and sample from a probability distribution\n",
    "    if temperature > 0.0:\n",
    "        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)\n",
    "        return multinomial_num_samples_1(probs)\n",
    "    return torch.argmax(logits, dim=-1, keepdim=True)\n",
    "\n",
    "def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:\n",
    "    logits = model(x, input_pos)\n",
    "    next = sample(logits, **kwargs)\n",
    "    return next.type_as(x)\n",
    "\n",
    "@torch.inference_mode()\n",
    "def generate(\n",
    "    model: GPT,\n",
    "    prompt: torch.Tensor,\n",
    "    max_returned_tokens: int,\n",
    "    *,\n",
    "    temperature: float = 1.0,\n",
    "    top_k: Optional[int] = None,\n",
    "    eos_id: Optional[int] = None,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n",
    "\n",
    "    The implementation of this function is modified from A. Karpathy's nanoGPT.\n",
    "\n",
    "    Args:\n",
    "        model: The model to use.\n",
    "        prompt: Tensor of shape (T) with indices of the prompt sequence.\n",
    "        max_returned_tokens: The maximum number of tokens to return (given plus generated).\n",
    "        temperature: Scales the predicted logits by 1 / temperature.\n",
    "        top_k: If specified, only sample among the tokens with the k highest probabilities.\n",
    "        eos_id: If specified, stop generating any more token once the <eos> token is triggered.\n",
    "    \"\"\"\n",
    "    T = prompt.size(0)\n",
    "    assert max_returned_tokens > T\n",
    "    if model.max_seq_length < max_returned_tokens - 1:\n",
    "        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a\n",
    "        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do\n",
    "        # not support it to avoid negatively impacting the overall speed\n",
    "        raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n",
    "\n",
    "    device = prompt.device\n",
    "    tokens = [prompt]\n",
    "    input_pos = torch.tensor([T], device=device)\n",
    "    token = next_token(\n",
    "        model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k\n",
    "    ).clone()\n",
    "    tokens.append(token)\n",
    "    for _ in range(2, max_returned_tokens - T + 1):\n",
    "        token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()\n",
    "        tokens.append(token)\n",
    "        if token == eos_id:\n",
    "            break\n",
    "        input_pos = input_pos.add_(1)\n",
    "    return torch.cat(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c30b09-e631-44b8-a45b-ac79dbbab52f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(\"[INFO] Using StableLM-3B base model\")\n",
    "checkpoint_dir: Path = Path(\"../checkpoints/stabilityai/stablelm-base-alpha-3b\")\n",
    "predictions_file_name = '../data/predictions-stablelm-base.json'\n",
    "\n",
    "quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\", \"gptq.int4\"]] = None\n",
    "max_new_tokens: int = 50\n",
    "top_k: int = 200\n",
    "temperature: float = 0.1\n",
    "strategy: str = \"auto\"\n",
    "devices: int = 1\n",
    "precision: Optional[str] = None\n",
    "num_samples: int = 1,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b842b257-2c9c-4c64-b011-f97e215c3794",
   "metadata": {},
   "outputs": [],
   "source": [
    "if strategy == \"fsdp\":\n",
    "    strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)\n",
    "fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)\n",
    "fabric.launch()\n",
    "\n",
    "check_valid_checkpoint_dir(checkpoint_dir)\n",
    "\n",
    "config = Config.from_json(checkpoint_dir / \"lit_config.json\")\n",
    "\n",
    "if quantize is not None and devices > 1:\n",
    "    raise NotImplementedError\n",
    "if quantize == \"gptq.int4\":\n",
    "    model_file = \"lit_model_gptq.4bit.pth\"\n",
    "    if not (checkpoint_dir / model_file).is_file():\n",
    "        raise ValueError(\"Please run `python quantize/gptq.py` first\")\n",
    "else:\n",
    "    model_file = \"lit_model.pth\"\n",
    "    \n",
    "checkpoint_path = checkpoint_dir / model_file\n",
    "tokenizer = Tokenizer(checkpoint_dir)\n",
    "encoded = tokenizer.encode(prompt, device=fabric.device)\n",
    "prompt_length = encoded.size(0)\n",
    "max_returned_tokens = prompt_length + max_new_tokens\n",
    "\n",
    "fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n",
    "t0 = time.perf_counter()\n",
    "with fabric.init_module(empty_init=True), gptq_quantization(quantize == \"gptq.int4\"):\n",
    "    model = GPT(config)\n",
    "fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n",
    "with fabric.init_tensor():\n",
    "    # set the max_seq_length to limit the memory usage to what we need\n",
    "    model.max_seq_length = max_returned_tokens\n",
    "    # enable the kv cache\n",
    "    model.set_kv_cache(batch_size=1)\n",
    "model.eval()\n",
    "\n",
    "model = fabric.setup_module(model)\n",
    "t0 = time.perf_counter()\n",
    "load_checkpoint(fabric, model, checkpoint_path)\n",
    "fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e00f69-d9fd-4cf5-a4ed-f142991b870f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eeaa237-7a95-4c9d-aeda-8a3c246d02fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "L.seed_everything(1234)\n",
    "\n",
    "t0 = time.perf_counter()\n",
    "y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)\n",
    "t = time.perf_counter() - t0\n",
    "for block in model.transformer.h:\n",
    "    block.attn.kv_cache.reset_parameters()\n",
    "output = tokenizer.decode(y)\n",
    "fabric.print(output)\n",
    "tokens_generated = y.size(0) - prompt_length"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "new",
   "language": "python",
   "name": "new"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}