In [None]:
import json
import sys
import time
from pathlib import Path
from typing import Literal, Optional

import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy

import os
# Put the notebook's parent directory first on the import path so the local
# packages imported below (`generate`, `lit_gpt`, `scripts`) can be found.
sys.path.insert(0, os.path.abspath(os.pardir))

from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.lora import GPT, Block, Config, merge_lora_weights
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
from scripts.prepare_entity_extraction_data import generate_prompt

In [None]:
# LoRA adapter settings, forwarded to Config.from_json in the model setup cell.
# NOTE(review): presumably these must mirror the settings used to fine-tune the
# adapter loaded from `lora_path` — confirm against the training script.
lora_r, lora_alpha, lora_dropout = 8, 16, 0.05

# Flags for Config's to_query / to_key / to_value / to_projection / to_mlp / to_head.
lora_query, lora_key, lora_value = True, False, True
lora_projection = lora_mlp = lora_head = False

# Let PyTorch use faster, slightly less precise float32 matmul kernels.
torch.set_float32_matmul_precision("high")

In [None]:
# Held-out test set; each record is later passed to generate_prompt() in the prediction loop.
# The path was previously '..data/...' (missing slash), which does not resolve to ../data/.
# JSON is UTF-8, so read it explicitly as such instead of relying on the platform default.
with open(Path("../data/entity_extraction/entity-extraction-test-data.json"), "r", encoding="utf-8") as file:
    test_data = json.load(file)

In [None]:
# A single synthetic support message used to smoke-test the fine-tuned model.
# "output" is the expected answer: a JSON string with the drug and its adverse events.
sample = dict(
    input=(
        "Natalie Cooper,\n"
        "ncooper@example.com\n"
        "6789 Birch Street, Denver, CO 80203,\n"
        "303-555-6543, United States\n"
        "\n"
        "Relationship to XYZ Pharma Inc.: Patient\n"
        "Reason for contacting: Adverse Event\n"
        "\n"
        "Message: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am "
        "experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper"
    ),
    output=json.dumps({"drug_name": "Abilify", "adverse_events": ["nausea", "vomiting"]}),
)

In [None]:
# Fine-tuned model to evaluate: "llama2" (Llama-2-7B) or "stablelm" (StableLM-3B).
model_type = "llama2"

In [None]:
# Resolve the LoRA adapter, the base checkpoint and the predictions output file
# for the selected `model_type`.
if model_type == "stablelm":
    print("[INFO] Using StableLM-3B LoRA Fine-tuned")
    lora_path: Path = Path("../out/lora/Stable-LM/entity_extraction/lit_model_lora_finetuned.pth")
    checkpoint_dir: Path = Path("../checkpoints/stabilityai/stablelm-base-alpha-3b")
    predictions_file_name = '../data/predictions-stablelm-lora.json'
elif model_type == "llama2":
    print("[INFO] Using LLaMa-2-7B  LoRA Fine-tuned")
    lora_path: Path = Path("../out/lora/Llama-2/entity_extraction/lit_model_lora_finetuned.pth")
    checkpoint_dir: Path = Path("../checkpoints/meta-llama/Llama-2-7b-hf")
    predictions_file_name = '../data/predictions-llama2-lora.json'
else:
    # Fail here with a clear message instead of a NameError on `lora_path` in the setup cell.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'stablelm' or 'llama2'.")

# Runtime / generation settings.
# "gptq.int4" expects `lit_model_gptq.4bit.pth` in `checkpoint_dir`; any quantization
# is rejected by the setup cell when devices > 1.
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None
max_new_tokens: int = 100  # upper bound on generated tokens per prompt
top_k: int = 200
temperature: float = 0.1  # low temperature: close to greedy decoding
strategy: str = "auto"  # "fsdp" switches the setup cell to an FSDPStrategy
devices: int = 1
precision: Optional[str] = None

In [None]:
# Model setup: build Fabric, load the base checkpoint plus the LoRA adapter, merge
# them, and size the KV cache for every prompt this notebook will run.
if quantize is not None and devices > 1:
    raise NotImplementedError("Quantized inference is only supported with devices=1.")

# With precision=None Fabric falls back to full 32-bit weights (roughly 28 GB for a 7B
# model); use the best precision the hardware supports instead. A new name keeps the
# config cell's `precision` untouched, so re-running this cell behaves the same.
fabric_precision = precision or get_default_supported_precision(training=False)

# The "bnb.*" quantization modes only take effect through Fabric's bitsandbytes
# precision plugin; without it they would be silently ignored.
plugins = None
if quantize is not None and quantize.startswith("bnb."):
    if "mixed" in fabric_precision:
        raise ValueError("Quantization and mixed precision is not supported.")
    dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[fabric_precision]
    plugins = BitsandbytesPrecision(quantize[4:], dtype)
    # The plugin now carries the precision, so Fabric's `precision` argument is left unset.
    fabric_precision = None

# Build the strategy under a new name rather than overwriting the config value `strategy`
# (otherwise a second run would reuse an already-constructed strategy object).
fabric_strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False) if strategy == "fsdp" else strategy
fabric = L.Fabric(devices=devices, precision=fabric_precision, strategy=fabric_strategy, plugins=plugins)
fabric.launch()

check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(
    checkpoint_dir / "lit_config.json",
    r=lora_r,
    alpha=lora_alpha,
    dropout=lora_dropout,
    to_query=lora_query,
    to_key=lora_key,
    to_value=lora_value,
    to_projection=lora_projection,
    to_mlp=lora_mlp,
    to_head=lora_head,
)

if quantize == "gptq.int4":
    model_file = "lit_model_gptq.4bit.pth"
    if not (checkpoint_dir / model_file).is_file():
        raise ValueError("Please run `python quantize/gptq.py` first")
else:
    model_file = "lit_model.pth"
checkpoint_path = checkpoint_dir / model_file

tokenizer = Tokenizer(checkpoint_dir)
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
# Generation budget for the demo `sample`: its prompt plus up to `max_new_tokens`.
max_returned_tokens = prompt_length + max_new_tokens

# The KV cache is allocated once below, so size it for the longest prompt the notebook
# will feed the model (the demo sample and every test record), not just the demo sample.
longest_prompt_length = max(
    [prompt_length] + [tokenizer.encode(generate_prompt(record)).size(0) for record in test_data]
)
max_seq_length_needed = longest_prompt_length + max_new_tokens

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
    model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
with fabric.init_tensor():
    # Limit memory usage to what the longest prompt plus `max_new_tokens` needs.
    model.max_seq_length = max_seq_length_needed
    # enable the kv cache
    model.set_kv_cache(batch_size=1)
model.eval()

t0 = time.perf_counter()
checkpoint = lazy_load(checkpoint_path)
lora_checkpoint = lazy_load(lora_path)
# Accept an adapter file that is either a bare state dict or one nested under "model";
# its entries are added to / override the base checkpoint's.
checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))

model.load_state_dict(checkpoint)
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

# Merge the LoRA weights into the model before handing it to Fabric.
merge_lora_weights(model)
model = fabric.setup(model)

In [None]:
# Display the formatted prompt that the setup cell built from `sample`.
prompt

In [None]:
# Demo: run the fine-tuned model once on `sample`, show its answer, then report speed/memory.
L.seed_everything(1234)  # fixed seed so the sampled demo output is reproducible

t0 = time.perf_counter()
y = generate(
    model,
    encoded,
    max_returned_tokens,
    temperature=temperature,
    top_k=top_k,
    eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0

# `y` holds prompt + completion; keep the text after the first "### Response:" marker
# (up to any second occurrence).
output = tokenizer.decode(y).split("### Response:")[1].strip()
fabric.print(output)

tokens_generated = y.size(0) - prompt_length
fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
if fabric.device.type == "cuda":
    fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

In [None]:
def parse_model_response(response):
    """Decode the model's JSON completion into Python objects.

    Args:
        response: Text produced by the model after the response marker, expected
            to be a JSON document such as
            '{"drug_name": "...", "adverse_events": ["..."]}'.

    Returns:
        The decoded JSON value; a dict of extracted entities for well-formed output.

    Raises:
        json.JSONDecodeError: If ``response`` is not valid JSON.
    """
    entities = json.loads(response)
    return entities

In [None]:
# Run the fine-tuned model over every test record and collect its raw predictions.
# Seeded like the demo cell so the sampled (temperature > 0) predictions are reproducible.
L.seed_everything(1234)

test_data_with_prediction = []
# A dedicated loop variable avoids clobbering the demo `sample`/`prompt`/`encoded` globals.
for record in test_data:
    record_prompt = generate_prompt(record)
    fabric.print(record_prompt)

    record_encoded = tokenizer.encode(record_prompt, device=fabric.device)
    record_prompt_length = record_encoded.size(0)

    # Give each record the same `max_new_tokens` budget measured from its own prompt
    # length (one budget derived from the demo sample truncates longer prompts' answers),
    # capped at the sequence length the model's KV cache was set up for.
    if record_prompt_length >= model.max_seq_length:
        raise ValueError(
            f"Prompt has {record_prompt_length} tokens but model.max_seq_length is "
            f"{model.max_seq_length}; increase it in the model setup cell."
        )
    record_max_returned_tokens = min(record_prompt_length + max_new_tokens, model.max_seq_length)

    # Generate the prediction from the LLM
    y = generate(
        model,
        record_encoded,
        record_max_returned_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )
    decoded = tokenizer.decode(y)

    # Keep the text after the first "### Response:" marker (up to any second one); fall
    # back to the full decoded text instead of crashing with IndexError if it is absent.
    parts = decoded.split("### Response:")
    prediction = (parts[1] if len(parts) > 1 else decoded).strip()

    # Copy the record rather than mutating `test_data` in place, so re-running this
    # cell starts from unmodified inputs.
    test_data_with_prediction.append({**record, "prediction": prediction})

    fabric.print(prediction)
    fabric.print("---------------------------------------------------------\n\n")

In [None]:
# Write the predictions data to a file
with open(predictions_file_name, 'w') as file:
    json.dump(test_data_with_prediction, file, indent=4)