In [1]:
import json
import sys
import time
from pathlib import Path
from typing import Literal, Optional

import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy

import os
# Put the parent directory at the front of sys.path so that the imports below
# (generate, lit_gpt, scripts) are resolved against it. The path is relative to
# the current working directory, so this assumes the notebook is run from a
# folder directly under the repository root -- TODO confirm.
sys.path.insert(0, os.path.abspath('../'))

from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.adapter_v2 import GPT, Block, Config
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
from scripts.prepare_entity_extraction_data import generate_prompt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Select PyTorch's "high" float32 matmul precision mode, which permits faster,
# slightly less precise float32 matrix multiplications on hardware that
# supports them (see the torch.set_float32_matmul_precision documentation).
torch.set_float32_matmul_precision("high")

In [3]:
# Load the entity-extraction test set that the prediction loop iterates over.
# The path is relative to the notebook's working directory, like every other
# path in this notebook ('../data', '../out', '../checkpoints').
with open('../data/entity_extraction/entity-extraction-test-data.json', 'r') as file:
    test_data = json.load(file)

In [4]:
# One hand-picked example used for the single-prompt smoke test below: it is
# passed to generate_prompt() to build the prompt, and "output" holds the
# expected answer as a JSON-encoded string.
sample = {
        "input": "Natalie Cooper,\nncooper@example.com\n6789 Birch Street, Denver, CO 80203,\n303-555-6543, United States\n\nRelationship to XYZ Pharma Inc.: Patient\nReason for contacting: Adverse Event\n\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper",
        "output": "{\"drug_name\": \"Abilify\", \"adverse_events\": [\"nausea\", \"vomiting\"]}"
    }

In [5]:
# Select which fine-tuned adapter to evaluate. The configuration cell that
# follows recognises exactly two values, 'stablelm' and 'llama2'; keep one
# assignment active and the other commented out.
model_type = 'stablelm'
#model_type = 'llama2'

In [8]:
# Raw message text of the single smoke-test example. NOTE: this name shadows
# the built-in input(); it is kept only so existing references keep working.
input: str = sample["input"]

# Resolve the model-specific artefacts: the fine-tuned adapter weights, the base
# checkpoint directory, and the file the predictions are written to.
if model_type == "stablelm":
    print("[INFO] Using StableLM-3B Adapter Fine-tuned")
    adapter_path: Path = Path("../out/adapter_v2/Stable-LM/entity_extraction/lit_model_adapter_finetuned.pth")
    checkpoint_dir: Path = Path("../checkpoints/stabilityai/stablelm-base-alpha-3b")
    predictions_file_name = '../data/predictions-stablelm-adapter.json'
elif model_type == "llama2":
    print("[INFO] Using LLaMa-2-7B  Adapter Fine-tuned")
    adapter_path: Path = Path("../out/adapter_v2/Llama-2/entity_extraction/lit_model_adapter_finetuned.pth")
    checkpoint_dir: Path = Path("../checkpoints/meta-llama/Llama-2-7b-hf")
    predictions_file_name = '../data/predictions-llama2-adapter.json'
else:
    # Fail here rather than several cells later with a NameError, or worse,
    # with paths left over from a previous run of this cell.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'stablelm' or 'llama2'")

# Quantization mode; None loads the regular (non-quantized) checkpoint.
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None
# Number of tokens to generate on top of the prompt.
max_new_tokens: int = 100
# Sampling settings forwarded to generate().
top_k: int = 200
temperature: float = 0.1
# Lightning Fabric settings; strategy "fsdp" is replaced by an FSDPStrategy
# object in the model-loading cell.
strategy: str = "auto"
devices: int = 1
precision: Optional[str] = None

StableLM 3B


In [9]:
# Set up Lightning Fabric, build the adapter-v2 GPT, and load the base weights
# overlaid with the fine-tuned adapter weights.

# Replace the "fsdp" shorthand with an FSDPStrategy that auto-wraps each Block
# and keeps parameters off the CPU.
if strategy == "fsdp":
    strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
fabric.launch()

# Helper from lit_gpt.utils; judging by its name it verifies that the directory
# holds a usable checkpoint -- TODO confirm exactly what it checks.
check_valid_checkpoint_dir(checkpoint_dir)

config = Config.from_json(checkpoint_dir / "lit_config.json")

# Quantization combined with more than one device is not supported by this cell.
if quantize is not None and devices > 1:
    raise NotImplementedError
# GPTQ int4 reads a separate, pre-quantized weight file that must already
# exist; every other setting reads the regular checkpoint file.
if quantize == "gptq.int4":
    model_file = "lit_model_gptq.4bit.pth"
    if not (checkpoint_dir / model_file).is_file():
        raise ValueError("Please run `python quantize/gptq.py` first")
else:
    model_file = "lit_model.pth"
checkpoint_path = checkpoint_dir / model_file

# Tokenize the single smoke-test prompt; its length determines the token budget.
tokenizer = Tokenizer(checkpoint_dir)
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, device=fabric.device)
prompt_length = encoded.size(0)
# NOTE(review): this budget is derived from this one prompt only, yet the
# prediction loop further down passes the same value to generate() for every
# test prompt, so longer prompts are left with fewer than max_new_tokens.
max_returned_tokens = prompt_length + max_new_tokens

fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
t0 = time.perf_counter()
# Instantiate the model with empty (uninitialized) weights; the real values are
# loaded from the checkpoints below.
with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
    model = GPT(config)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
with fabric.init_tensor():
    # set the max_seq_length to limit the memory usage to what we need
    model.max_seq_length = max_returned_tokens
    # enable the kv cache
    model.set_kv_cache(batch_size=1)
model.eval()

t0 = time.perf_counter()
checkpoint = lazy_load(checkpoint_path)
adapter_checkpoint = lazy_load(adapter_path)
# Overlay the adapter weights on the base weights. The adapter file may nest
# its weights under a "model" key; if that key is absent the whole mapping is used.
checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
model.load_state_dict(checkpoint)
fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

model = fabric.setup(model)

Loading model 'checkpoints/stabilityai/stablelm-base-alpha-3b/lit_model.pth' with {'name': 'stablelm-base-alpha-3b', 'hf_config': {'org': 'stabilityai', 'name': 'stablelm-base-alpha-3b'}, 'block_size': 4096, 'vocab_size': 50254, 'padding_multiple': 512, 'padded_vocab_size': 50688, 'n_layer': 16, 'n_head': 32, 'n_embd': 4096, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 32, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 16384, 'rope_condense_ratio': 1, 'rope_base': 10000, 'adapter_prompt_length': 10, 'adapter_start_layer': 2, 'head_size': 128, 'rope_n_elem': 32}
Time to instantiate model: 0.07 seconds.
Time to load the model weights: 47.84 seconds.


In [11]:
# Display the prompt built from `sample` exactly as it is fed to the tokenizer.
prompt

"### Input:\nNatalie Cooper,\nncooper@example.com\n6789 Birch Street, Denver, CO 80203,\n303-555-6543, United States\n\nRelationship to XYZ Pharma Inc.: Patient\nReason for contacting: Adverse Event\n\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper\n\n### Response:"

In [12]:
# Smoke test on the single example: seed the RNGs so sampling is repeatable,
# time one generate() call, then report throughput and memory use.
L.seed_everything(1234)

t0 = time.perf_counter()
y = generate(
    model,
    encoded,
    max_returned_tokens,
    temperature=temperature,
    top_k=top_k,
    eos_id=tokenizer.eos_id,
)
t = time.perf_counter() - t0

# Keep only the text that follows the response marker of the prompt template.
output = tokenizer.decode(y).split("### Response:")[1].strip()
fabric.print(output)

# Everything in `y` beyond the prompt tokens counts as newly generated.
tokens_generated = y.size(0) - prompt_length
fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
if fabric.device.type == "cuda":
    fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

[rank: 0] Seed set to 1234


{"drug_name": "Abilify", "adverse_events": ["nausea", "vomiting"]}




Time for inference: 1.49 sec total, 18.77 tokens/sec
Memory used: 14.72 GB


In [13]:
# Display the post-processed completion (the text after "### Response:").
output

'{"drug_name": "Abilify", "adverse_events": ["nausea", "vomiting"]}'

In [14]:
def parse_model_response(response: str):
    """Decode the JSON text produced by the model.

    Args:
        response: The model's completion, expected to be a JSON document.

    Returns:
        The decoded JSON value; for a well-formed completion this is a
        dictionary of the extracted entities.

    Raises:
        json.JSONDecodeError: If ``response`` is not valid JSON.
    """
    entities = json.loads(response)
    return entities

In [15]:
# Run the fine-tuned model over every test example and collect, for each one,
# a copy of the example extended with the model's "prediction".
test_data_with_prediction = []
# A dedicated loop variable keeps the notebook-level `sample` (the single
# smoke-test example) intact, so earlier cells can be re-run safely.
for test_sample in test_data:
    # Generate prompt from sample
    prompt = generate_prompt(test_sample)
    fabric.print(prompt)

    # Encode the prompt
    encoded = tokenizer.encode(prompt, device=fabric.device)

    # Give this prompt its own budget of max_new_tokens instead of reusing the
    # budget computed for the smoke-test prompt. The model's sequence length
    # and kv cache were sized to max_returned_tokens, so never ask for more
    # than that: prompts longer than the smoke-test prompt therefore still get
    # fewer than max_new_tokens.
    sample_max_returned_tokens = min(encoded.size(0) + max_new_tokens, max_returned_tokens)

    # Generate the prediction from the LLM
    y = generate(
        model,
        encoded,
        sample_max_returned_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )
    output = tokenizer.decode(y)

    # Process the predicted completion
    output = output.split("### Response:")[1].strip()

    # Store prediction along with input and ground truth. A new dict is built
    # so that test_data itself is not modified and this cell stays idempotent.
    test_data_with_prediction.append({**test_sample, "prediction": output})

    fabric.print(output)
    fabric.print("---------------------------------------------------------\n\n")

### Input:
Natalie Cooper,
ncooper@example.com
6789 Birch Street, Denver, CO 80203,
303-555-6543, United States

Relationship to XYZ Pharma Inc.: Patient
Reason for contacting: Adverse Event

Message: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper

### Response:
{"drug_name": "Abilify", "adverse_events": ["nausea", "vomiting"]}
---------------------------------------------------------


### Input:
Mia Garcia
mia.garcia@email.com
321 Magnolia Drive, Dallas, TX 75201
555-890-1234, United States

Relationship to XYZ Pharma Inc.: Patient
Reason for contacting: Adverse Event

Message: I experienced a feeling of light-headedness and near-fainting after taking Staxyn for my erectile dysfunction. Is this a common side effect, and should I be worried?

### Response:
{"drug_name": "Staxyn", "adverse_events": ["feeling light-headed", "near-fainting"]}
------------------------------------

In [16]:
# Persist every test example together with its prediction as indented JSON at
# the location chosen for the selected model.
with open(predictions_file_name, 'w') as file:
    file.write(json.dumps(test_data_with_prediction, indent=4))