"""Implmentation derived from lit-gpt package from Lightning Ai."""
import json
import sys
from pathlib import Path
## This script is adapted from the original script in the lit-gpt package by Lightning AI.
## The original script is available at https://github.com/Lightning-AI/lit-gpt
## It prepares the data for fine-tuning LLMs on an entity extraction task.
from typing import Optional
import requests
import torch
from torch.utils.data import random_split
from tqdm import tqdm
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from lit_gpt.tokenizer import Tokenizer
def prepare(
destination_path: Path = Path("data/entity_extraction"),
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    test_split_fraction: float = 0.05,  # fraction of samples held out for the test split (5%)
seed: int = 42,
mask_inputs: bool = False, # as in alpaca-lora
data_file_name: str = "entity-extraction-train-data.json",
data_file_url: str = "https://raw.githubusercontent.com/architkaila/Entity-Extraction-with-LLMs/main/data/entity_extraction/entity-extraction-train-data.json",
ignore_index: int = -1,
max_seq_length: Optional[int] = None,
) -> None:
"""Prepare custom dataset for medical entity extraction.
    The output is a training and a test dataset saved as `train.pt` and `test.pt`,
    which store the preprocessed and tokenized prompts and labels.
"""
if max_seq_length is None:
with open(checkpoint_dir / "lit_config.json", "r", encoding="utf-8") as file:
config = json.load(file)
max_seq_length = config["block_size"]
destination_path.mkdir(parents=True, exist_ok=True)
data_file_path = destination_path / data_file_name
print("Loading data file...")
download_if_missing(data_file_path, data_file_url)
with open(data_file_path, "r", encoding="utf-8") as file:
data = json.load(file)
print("Loading tokenizer...")
tokenizer = Tokenizer(checkpoint_dir)
# Partition the dataset into train and test
train_set, test_set = random_split(
data, [1.0 - test_split_fraction, test_split_fraction], generator=torch.Generator().manual_seed(seed)
)
train_set, test_set = list(train_set), list(test_set)
print(f"train has {len(train_set):,} samples")
print(f"test has {len(test_set):,} samples")
print("Processing train split ...")
train_set = [
prepare_sample(
example=sample,
tokenizer=tokenizer,
max_length=max_seq_length,
mask_inputs=mask_inputs,
ignore_index=ignore_index,
)
for sample in tqdm(train_set)
]
torch.save(train_set, destination_path / "train.pt")
print("Processing test split ...")
test_set = [
prepare_sample(
example=sample,
tokenizer=tokenizer,
max_length=max_seq_length,
mask_inputs=mask_inputs,
ignore_index=ignore_index,
)
for sample in tqdm(test_set)
]
torch.save(test_set, destination_path / "test.pt")
def download_if_missing(file_path: Path, file_url: str) -> None:
"""Downloads the raw json data file and saves it in the given destination."""
if file_path.exists() and file_path.stat().st_size > 0:
return
with open(file_path, "w", encoding="utf-8") as f:
f.write(requests.get(file_url).text)
def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict:
"""Processes a single sample.
Each sample in the dataset consists of:
    - input: A string containing the adverse event description (e.g., an email)
    - output: The response string with the entities extracted from the input
This function processes this data to produce a prompt text and a label for
    supervised training. The prompt text is the input from the sample. The label/target is the same message but with the
output attached.
Finally, both the prompt and the label get tokenized. If desired, all tokens
    in the label that correspond to the original input prompt get masked out (disabled by default via `mask_inputs`).
"""
full_prompt = generate_prompt(example)
full_prompt_and_response = full_prompt + example["output"]
encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length)
encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)
    # The labels are the full prompt with response; if mask_inputs is set, the prompt tokens are masked out below
labels = encoded_full_prompt_and_response.clone()
if mask_inputs:
labels[: len(encoded_full_prompt)] = ignore_index
return {
**example,
"input_ids": encoded_full_prompt_and_response,
"input_ids_no_response": encoded_full_prompt,
"labels": labels,
}
def generate_prompt(example: dict) -> str:
"""Generates a standardized message to prompt the model with an input and a
'response' field."""
# Prepare the prompt for entity extraction
return (f"### Input:\n{example['input']}\n\n### Response:")
if __name__ == "__main__":
from jsonargparse import CLI
CLI(prepare)
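# Example invocation (script path and arguments are illustrative; any parameter of
# prepare() can be passed as a command-line flag via jsonargparse):
#   python path/to/this_script.py \
#       --destination_path data/entity_extraction \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --test_split_fraction 0.05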