{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae13342-5979-42df-9fea-5bb97c60be23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sys\n",
    "import time\n",
    "from pathlib import Path\n",
    "from typing import Literal, Optional\n",
    "\n",
    "import lightning as L\n",
    "import torch\n",
    "from lightning.fabric.plugins import BitsandbytesPrecision\n",
    "from lightning.fabric.strategies import FSDPStrategy\n",
    "\n",
    "import os\n",
    "## Add the lit_gpt folder to the path\n",
    "sys.path.insert(0, os.path.abspath('../'))\n",
    "\n",
    "from generate.base import generate\n",
    "from lit_gpt import Tokenizer\n",
    "from lit_gpt.lora import GPT, Block, Config, merge_lora_weights\n",
    "from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load\n",
    "from scripts.prepare_entity_extraction_data import generate_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac87e9a-6846-4f84-9c02-27f46fa417ce",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "lora_r = 8\n",
    "lora_alpha = 16\n",
    "lora_dropout = 0.05\n",
    "lora_query = True\n",
    "lora_key = False\n",
    "lora_value = True\n",
    "lora_projection = False\n",
    "lora_mlp = False\n",
    "lora_head = False\n",
    "\n",
    "torch.set_float32_matmul_precision(\"high\")"
   ]
  },
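  {
   "cell_type": "markdown",
   "id": "d3b1a2c4-lora-scaling-note",
   "metadata": {},
   "source": [
    "A quick worked check of the configuration above: in the standard LoRA formulation the low-rank update is scaled by `alpha / r`, i.e. the effective weight becomes `W + (alpha / r) * B @ A`. The cell below is a small illustrative sketch, assuming lit_gpt follows that standard scaling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c2b3d5-lora-scaling-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard LoRA scaling: the update B @ A is multiplied by alpha / r,\n",
    "# so with r=8 and alpha=16 the LoRA delta is scaled by a factor of 2.0\n",
    "print(f\"LoRA scaling factor: {lora_alpha / lora_r}\")"
   ]
  },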
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6eec0b2-81d9-4a88-b20c-0be57ec6b2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('..data/entity_extraction/entity-extraction-test-data.json', 'r') as file:\n",
    "    test_data = json.load(file)"
   ]
  },
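  {
   "cell_type": "markdown",
   "id": "f5d3c4e6-test-data-peek",
   "metadata": {},
   "source": [
    "A quick look at the loaded test set. The structure (a list of dicts with `input` and `output` keys) is assumed from how the samples are used later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e4d5f7-test-data-size",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each sample is expected to carry an 'input' message and a JSON-string 'output'\n",
    "print(f\"Loaded {len(test_data)} test samples\")\n",
    "print(json.dumps(test_data[0], indent=2))"
   ]
  },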
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d92ad07",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = {\n",
    "        \"input\": \"Natalie Cooper,\\nncooper@example.com\\n6789 Birch Street, Denver, CO 80203,\\n303-555-6543, United States\\n\\nRelationship to XYZ Pharma Inc.: Patient\\nReason for contacting: Adverse Event\\n\\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper\",\n",
    "        \"output\": \"{\\\"drug_name\\\": \\\"Abilify\\\", \\\"adverse_events\\\": [\\\"nausea\\\", \\\"vomiting\\\"]}\"\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7e65d14-07f2-4889-849f-277df3887c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Choose one\n",
    "#model_type = 'stablelm'\n",
    "model_type = 'llama2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c30b09-e631-44b8-a45b-ac79dbbab52f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "input: str = sample[\"input\"]\n",
    "if model_type == \"stablelm\":\n",
    "    print(\"[INFO] Using StableLM-3B LoRA Fine-tuned\")\n",
    "    lora_path: Path = Path(\"../out/lora/Stable-LM/entity_extraction/lit_model_lora_finetuned.pth\")\n",
    "    checkpoint_dir: Path = Path(\"../checkpoints/stabilityai/stablelm-base-alpha-3b\")\n",
    "    predictions_file_name = '../data/predictions-stablelm-lora.json'\n",
    "\n",
    "if model_type == \"llama2\":\n",
    "    print(\"[INFO] Using LLaMa-2-7B  LoRA Fine-tuned\")\n",
    "    lora_path: Path = Path(\"../out/lora/Llama-2/entity_extraction/lit_model_lora_finetuned.pth\")\n",
    "    checkpoint_dir: Path = Path(\"../checkpoints/meta-llama/Llama-2-7b-hf\")\n",
    "    predictions_file_name = '../data/predictions-llama2-lora.json'\n",
    "\n",
    "quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\", \"gptq.int4\"]] = None\n",
    "max_new_tokens: int = 100\n",
    "top_k: int = 200\n",
    "temperature: float = 0.1\n",
    "strategy: str = \"auto\"\n",
    "devices: int = 1\n",
    "precision: Optional[str] = None"
   ]
  },
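  {
   "cell_type": "markdown",
   "id": "b7f5e6a8-path-guard-note",
   "metadata": {},
   "source": [
    "An optional guard cell (not part of the original flow): fail fast with a readable message if the fine-tuned LoRA weights or the base checkpoint directory are missing, rather than erroring midway through model setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a6f7b9-path-guard",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast if the expected artifacts are not on disk\n",
    "assert lora_path.is_file(), f\"LoRA checkpoint not found: {lora_path}\"\n",
    "assert checkpoint_dir.is_dir(), f\"Checkpoint directory not found: {checkpoint_dir}\""
   ]
  },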
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b842b257-2c9c-4c64-b011-f97e215c3794",
   "metadata": {},
   "outputs": [],
   "source": [
    "if strategy == \"fsdp\":\n",
    "    strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)\n",
    "fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)\n",
    "fabric.launch()\n",
    "\n",
    "check_valid_checkpoint_dir(checkpoint_dir)\n",
    "\n",
    "config = Config.from_json(\n",
    "    checkpoint_dir / \"lit_config.json\",\n",
    "    r=lora_r,\n",
    "    alpha=lora_alpha,\n",
    "    dropout=lora_dropout,\n",
    "    to_query=lora_query,\n",
    "    to_key=lora_key,\n",
    "    to_value=lora_value,\n",
    "    to_projection=lora_projection,\n",
    "    to_mlp=lora_mlp,\n",
    "    to_head=lora_head,\n",
    ")\n",
    "\n",
    "if quantize is not None and devices > 1:\n",
    "    raise NotImplementedError\n",
    "if quantize == \"gptq.int4\":\n",
    "    model_file = \"lit_model_gptq.4bit.pth\"\n",
    "    if not (checkpoint_dir / model_file).is_file():\n",
    "        raise ValueError(\"Please run `python quantize/gptq.py` first\")\n",
    "else:\n",
    "    model_file = \"lit_model.pth\"\n",
    "checkpoint_path = checkpoint_dir / model_file\n",
    "\n",
    "tokenizer = Tokenizer(checkpoint_dir)\n",
    "prompt = generate_prompt(sample)\n",
    "encoded = tokenizer.encode(prompt, device=fabric.device)\n",
    "prompt_length = encoded.size(0)\n",
    "max_returned_tokens = prompt_length + max_new_tokens\n",
    "\n",
    "fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n",
    "t0 = time.perf_counter()\n",
    "with fabric.init_module(empty_init=True), gptq_quantization(quantize == \"gptq.int4\"):\n",
    "    model = GPT(config)\n",
    "fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n",
    "with fabric.init_tensor():\n",
    "    # set the max_seq_length to limit the memory usage to what we need\n",
    "    model.max_seq_length = max_returned_tokens\n",
    "    # enable the kv cache\n",
    "    model.set_kv_cache(batch_size=1)\n",
    "model.eval()\n",
    "\n",
    "t0 = time.perf_counter()\n",
    "checkpoint = lazy_load(checkpoint_path)\n",
    "lora_checkpoint = lazy_load(lora_path)\n",
    "checkpoint.update(lora_checkpoint.get(\"model\", lora_checkpoint))\n",
    "\n",
    "model.load_state_dict(checkpoint)\n",
    "fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n",
    "\n",
    "merge_lora_weights(model)\n",
    "model = fabric.setup(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e00f69-d9fd-4cf5-a4ed-f142991b870f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eeaa237-7a95-4c9d-aeda-8a3c246d02fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "L.seed_everything(1234)\n",
    "t0 = time.perf_counter()\n",
    "\n",
    "y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)\n",
    "t = time.perf_counter() - t0\n",
    "\n",
    "output = tokenizer.decode(y)\n",
    "output = output.split(\"### Response:\")[1].strip()\n",
    "fabric.print(output)\n",
    "\n",
    "tokens_generated = y.size(0) - prompt_length\n",
    "fabric.print(f\"\\n\\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr)\n",
    "if fabric.device.type == \"cuda\":\n",
    "    fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d286fb-d595-4d8c-94a1-06ec7357722f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_model_response(response):\n",
    "    \"\"\"\n",
    "    Parse the model response to extract entities.\n",
    "\n",
    "    Args:\n",
    "    - response: A string representing the model's response.\n",
    "\n",
    "    Returns:\n",
    "    - A dictionary containing the extracted entities.\n",
    "    \"\"\"\n",
    "    return json.loads(response)"
   ]
  },
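  {
   "cell_type": "markdown",
   "id": "d9b7a8c1-safe-parse-note",
   "metadata": {},
   "source": [
    "`json.loads` raises `json.JSONDecodeError` when the model emits malformed JSON. The guarded variant below is a minimal sketch (the helper name is ours, not part of the original pipeline) that returns `None` instead of aborting, plus a sanity check on the single prediction generated above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c8b9d2-safe-parse",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safe_parse_model_response(response):\n",
    "    \"\"\"Like parse_model_response, but returns None on malformed JSON.\"\"\"\n",
    "    try:\n",
    "        return json.loads(response)\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "\n",
    "# Sanity check on the prediction generated a few cells earlier\n",
    "print(safe_parse_model_response(output))"
   ]
  },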
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faee8a5e-9e3c-4d95-b351-d19a46c3f0a2",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_data_with_prediction = []\n",
    "for sample in test_data:\n",
    "    # Generate prompt from sample\n",
    "    prompt = generate_prompt(sample)\n",
    "    fabric.print(prompt)\n",
    "    \n",
    "    # Encode the prompt\n",
    "    encoded = tokenizer.encode(prompt, device=fabric.device)\n",
    "    \n",
    "    # Generate the prediction from the LLM\n",
    "    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)\n",
    "    output = tokenizer.decode(y)\n",
    "    \n",
    "    # Process the predicted completion\n",
    "    output = output.split(\"### Response:\")[1].strip()\n",
    "    \n",
    "    # Store prediction along with input and ground truth\n",
    "    sample['prediction'] = output\n",
    "    test_data_with_prediction.append(sample)\n",
    "    \n",
    "    fabric.print(output)\n",
    "    fabric.print(\"---------------------------------------------------------\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1616278a-0dca-48b7-b254-f1bd957942ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the predictions data to a file\n",
    "with open(predictions_file_name, 'w') as file:\n",
    "    json.dump(test_data_with_prediction, file, indent=4)"
   ]
  },
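  {
   "cell_type": "markdown",
   "id": "f2d9c1e3-eval-note",
   "metadata": {},
   "source": [
    "With predictions stored next to the ground truth, a simple exact-match evaluation can be run over `test_data_with_prediction`. This is a minimal sketch assuming both `output` and `prediction` hold JSON strings of the form `{\"drug_name\": ..., \"adverse_events\": [...]}`; predictions that fail to parse are counted as misses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e1d2f4-eval-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal exact-match evaluation (assumes the JSON schema shown above)\n",
    "drug_hits, event_hits, parsed = 0, 0, 0\n",
    "for item in test_data_with_prediction:\n",
    "    truth = parse_model_response(item['output'])\n",
    "    try:\n",
    "        pred = parse_model_response(item['prediction'])\n",
    "    except json.JSONDecodeError:\n",
    "        continue  # a malformed prediction counts as a miss\n",
    "    parsed += 1\n",
    "    drug_hits += int(pred.get('drug_name') == truth.get('drug_name'))\n",
    "    event_hits += int(sorted(pred.get('adverse_events', [])) == sorted(truth.get('adverse_events', [])))\n",
    "\n",
    "n = len(test_data_with_prediction)\n",
    "print(f\"Parsed predictions: {parsed}/{n}\")\n",
    "print(f\"Drug name exact match: {drug_hits}/{n}\")\n",
    "print(f\"Adverse events exact match: {event_hits}/{n}\")"
   ]
  }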
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "new",
   "language": "python",
   "name": "new"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}