generate/inference_adapter.py
# This script is used to generate predictions using the fine-tuned adapter models.
# It is modified from the original script provided by the Lightning AI lit-gpt team: https://github.com/Lightning-AI/lit-gpt

## Usage:
# python generate/inference_adapter.py --model-type "stablelm" --input-file "data/entity_extraction/entity-extraction-test-data.json"
# python generate/inference_adapter.py --model-type "llama2" --input-file "data/entity_extraction/entity-extraction-test-data.json"

import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Literal, Optional

import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy

## Add the lit_gpt folder to the path
sys.path.insert(0, os.path.abspath('../'))

from base import generate
from lit_gpt import Tokenizer
from lit_gpt.adapter_v2 import GPT, Block, Config
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
from scripts.prepare_entity_extraction_data import generate_prompt

def generate_prediction(model_type, sample):
    """
    Generate predictions using the fine-tuned adapter models. This function loads the selected
    base model together with its adapter weights, generates and prints a prediction for the given
    sample, then generates predictions for all samples in the module-level `test_data` list and
    stores them in a JSON file.

    Args:
        model_type (str): The type of model to use for prediction ("stablelm" or "llama2")
        sample (dict): The sample for which the demo prediction is to be generated

    Returns:
        None
    """

    # Check which model to use for prediction
    if model_type == "stablelm":
        print("[INFO] Using StableLM-3B Adapter Fine-tuned")
        adapter_path: Path = Path("out/adapter_v2/Stable-LM/entity_extraction/lit_model_adapter_finetuned.pth")
        checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b")
        predictions_file_name = 'data/predictions-stablelm-adapter.json'
    elif model_type == "llama2":
        print("[INFO] Using LLaMa-2-7B Adapter Fine-tuned")
        adapter_path: Path = Path("out/adapter_v2/Llama-2/entity_extraction/lit_model_adapter_finetuned.pth")
        checkpoint_dir: Path = Path("checkpoints/meta-llama/Llama-2-7b-hf")
        predictions_file_name = 'data/predictions-llama2-adapter.json'
    else:
        raise ValueError(f"Unsupported model_type: {model_type!r}; expected 'stablelm' or 'llama2'")

    # Set the model parameters
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None
    max_new_tokens: int = 100
    top_k: int = 200
    temperature: float = 0.1
    strategy: str = "auto"
    devices: int = 1
    precision: Optional[str] = None

    # Set the strategy
    if strategy == "fsdp":
        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)
    fabric.launch()

    # Check if the checkpoint directory is valid and load the model configuration
    check_valid_checkpoint_dir(checkpoint_dir)
    config = Config.from_json(checkpoint_dir / "lit_config.json")

    # Check if the quantization is required
    if quantize is not None and devices > 1:
        raise NotImplementedError
    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        model_file = "lit_model.pth"

    # Build the path to the base model checkpoint
    checkpoint_path = checkpoint_dir / model_file

    # Load the tokenizer
    tokenizer = Tokenizer(checkpoint_dir)

    # Generate the prompt from the given sample and encode it
    prompt = generate_prompt(sample)
    encoded = tokenizer.encode(prompt, device=fabric.device)

    # Set the max sequence length
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    # Instantiate the model
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
        model = GPT(config)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    # Load the model weights and setup the adapter
    t0 = time.perf_counter()
    checkpoint = lazy_load(checkpoint_path)
    adapter_checkpoint = lazy_load(adapter_path)
    checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint))
    model.load_state_dict(checkpoint)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    model = fabric.setup(model)

    # Set the seed and generate the prediction
    L.seed_everything(1234)
    t0 = time.perf_counter()
    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
    t = time.perf_counter() - t0

    # Process the predicted completion
    output = tokenizer.decode(y)
    output = output.split("### Response:")[1].strip()
    fabric.print(output)
    tokens_generated = y.size(0) - prompt_length
    fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

    # Generate predictions for all the samples in the test data
    test_data_with_prediction = []
    for sample in test_data:
        # Generate prompt from sample
        prompt = generate_prompt(sample)
        fabric.print(prompt)

        # Encode the prompt
        encoded = tokenizer.encode(prompt, device=fabric.device)

        # Generate the prediction from the LLM
        y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
        output = tokenizer.decode(y)

        # Process the predicted completion
        output = output.split("### Response:")[1].strip()

        # Store prediction along with input and ground truth
        sample['prediction'] = output
        test_data_with_prediction.append(sample)

        fabric.print(output)
        fabric.print("---------------------------------------------------------\n\n")

    # Write the predictions data to a file
    with open(predictions_file_name, 'w') as file:
        json.dump(test_data_with_prediction, file, indent=4)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    # Parse the arguments
    parser = argparse.ArgumentParser(description="Entity Extraction Script")
    parser.add_argument('--input-file', type=str, default='data/entity_extraction/entity-extraction-test-data.json', help="Path to the test JSON file")
    parser.add_argument('--model-type', type=str, choices=['stablelm', 'llama2'], default='stablelm', help="Type of model to use (stablelm or llama2)")
    args = parser.parse_args()

    # Load the test data
    with open(args.input_file, 'r') as file:
        test_data = json.load(file)

    # Single Sample
    sample = {
        "input": "Natalie Cooper,\nncooper@example.com\n6789 Birch Street, Denver, CO 80203,\n303-555-6543, United States\n\nRelationship to XYZ Pharma Inc.: Patient\nReason for contacting: Adverse Event\n\nMessage: Hi, after starting Abilify for bipolar I disorder, I've noticed that I am experiencing nausea and vomiting. Are these typical reactions? Best, Natalie Cooper",
        "output": "{\"drug_name\": \"Abilify\", \"adverse_events\": [\"nausea\", \"vomiting\"]}"
    }

    generate_prediction(model_type=args.model_type, sample=sample)
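
# For reference, each entry in the predictions file written by generate_prediction() will look
# roughly like the sketch below, assuming the test data uses the same "input"/"output" keys as the
# sample above; the "prediction" field holds the raw model completion added by this script.
#
# {
#     "input": "<customer message describing the adverse event>",
#     "output": "{\"drug_name\": \"...\", \"adverse_events\": [...]}",
#     "prediction": "{\"drug_name\": \"...\", \"adverse_events\": [...]}"
# }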