ATML_part2.ipynb

{"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"accelerator":"GPU","colab":{"collapsed_sections":["XKL46PcIc21t","PJljpW-Jc37a","f4wIndWS2lyy","-BMh1vbuc5qp","a_6XDhvKc37b","vR87HRbfce2i","4MovmBY3c3Rm","5kca4Wt6dGMb","uJGGWVExc37f","3QSzEjW2MZrc"],"gpuType":"T4","include_colab_link":true,"provenance":[]},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"sourceId":218074,"sourceType":"modelInstanceVersion","isSourceIdPinned":true,"modelInstanceId":185968,"modelId":208088}],"dockerImageVersionId":30823,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"<a href=\"https://colab.research.google.com/github/sAndreotti/MedicalMeadow/blob/main/ATML_part2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>","metadata":{"colab_type":"text","id":"view-in-github"}},{"cell_type":"markdown","source":"# MedicalMeadow ChatBot fine-tuning Llama 3.2 1B\nMedicalMeadow is a project focused on training a chatbot using the LLaMA model, fine-tuned with the Medical Meadow dataset. The aim is to develop a robust NLP system capable of answering medical questions effectively.","metadata":{}},{"cell_type":"markdown","source":"## Libraries","metadata":{"id":"29nxqgohSsQ3"}},{"cell_type":"markdown","source":"### Download","metadata":{"id":"K2eUNhk9tprU"}},{"cell_type":"code","source":"!pip install datasets accelerate peft transformers trl==0.12.0 plotly huggingface_hub\n!pip install --upgrade smart_open\n!pip install --upgrade gensim\n!pip install ffmpeg-python\n!pip install -U openai-whisper\n!pip install scipy librosa unidecode inflect\n!pip install unsloth\n!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git\n!pip install TTS\n!pip uninstall -y bitsandbytes\n!pip install bitsandbytes\n!pip install nltk\n!pip install python-dotenv","metadata":{"id":"SAho3HGib9-U","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Import","metadata":{"id":"8b63DJghtprV"}},{"cell_type":"code","source":"from datasets import load_dataset\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom collections import Counter\nfrom trl import SFTTrainer\nimport re\nfrom gensim.models.word2vec import Word2Vec\nimport plotly.express as px\nimport random\nfrom sklearn.manifold import TSNE\nfrom torch.utils.data import Dataset\nfrom torch.utils.data import random_split\nfrom torch.utils.data import Subset\nfrom peft import prepare_model_for_kbit_training, LoraConfig\nimport torch\nfrom transformers import (\n    AutoTokenizer,\n    AutoModelForCausalLM,\n    BitsAndBytesConfig,\n    TrainingArguments,\n    DataCollatorForSeq2Seq,\n    AutoModelForSpeechSeq2Seq,\n    AutoProcessor,\n    pipeline,\n    TextStreamer\n)\nfrom peft import AutoPeftModelForCausalLM\nimport pandas as pd\nfrom wordcloud import WordCloud\nimport nltk\nfrom nltk.translate.bleu_score import sentence_bleu\nfrom tabulate import tabulate\n\nfrom unsloth import FastLanguageModel\nfrom unsloth.chat_templates import get_chat_template\nfrom unsloth.chat_templates import train_on_responses_only\nfrom TTS.api import TTS\nfrom dotenv 
import load_dotenv\nfrom huggingface_hub import login\nimport os","metadata":{"id":"q_JimYqjjY4S","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Import for audio section","metadata":{"id":"E2DDMWKGSsQ5"}},{"cell_type":"code","source":"from IPython.display import HTML, Audio\nfrom google.colab.output import eval_js\nfrom base64 import b64decode\nfrom scipy.io.wavfile import read as wav_read\nimport io\nimport ffmpeg\nimport scipy\nimport whisper\n\nimport librosa\nimport soundfile as sf","metadata":{"id":"lkjSoQT4SsQ6","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Hugging Face settings","metadata":{"id":"zcC-cug1tprW"}},{"cell_type":"markdown","source":"In order to carry out operations with Llama and to be authorised to use it, it is necessary to log in to Hugging Face and request access to the model. The HF access token must be saved in a .env file","metadata":{}},{"cell_type":"code","source":"load_dotenv()\nlogin(token=os.environ.get('HF_TOKEN'))","metadata":{"id":"40cASCEmtprW","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Investigate Dataset","metadata":{"id":"XKL46PcIc21t"}},{"cell_type":"markdown","source":"This dataset, [Medical Meadow](https://huggingface.co/datasets/medalpaca/medical_meadow_medical_flashcards?row=0), covers basic medical sciences, clinical knowledge, and clinical skills. The embedded flashcards are created and updated by medical students and cover the entirety of this curriculum, addressing subjects such as anatomy, physiology, pathology, pharmacology, and more. These flashcards frequently feature succinct summaries and mnemonics to aid in learning and retention of vital medical concepts.","metadata":{}},{"cell_type":"markdown","source":"In this section, we analyse the dataset under consideration","metadata":{}},{"cell_type":"code","source":"# Download Medical Meadow dataset from HF\nds = load_dataset(\"medalpaca/medical_meadow_medical_flashcards\")\n\n# The dataset consists of just one split: 'train'\nds = ds['train']\nds","metadata":{"id":"YltBm7i4b0IM","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Analyze the dataset: input, output, instruction\nprint(ds.features, \"\\n\")\nprint(\"Instruction:\")\nprint(f\"length: {len(ds['instruction'])}\")\nprint(f\"example: {ds['instruction'][0]} \\n\")\n\nprint(f\"Input:\")\nprint(f\"length: {len(ds['input'])}\")\nprint(f\"example: {ds['input'][0]} \\n\")\n\nprint(f\"Output:\")\nprint(f\"length: {len(ds['output'])}\")\nprint(f\"example: {ds['output'][0]} \\n\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Some plots about the dataset","metadata":{"id":"a9SXu1UFoE4x"}},{"cell_type":"code","source":"# Better usability\ninstructions = ds['instruction']\ninput_phrases = ds['input']\noutput_phrases = ds['output']","metadata":{"id":"nLEtHKtISsQ7","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"Plot of the distribution of instruction classes (just 1 class for all instances)","metadata":{"id":"TWAxDZ05j1Pj"}},{"cell_type":"code","source":"%matplotlib inline","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Count the frequency of each unique instruction\ninstruction_counts = {instruction: instructions.count(instruction) for instruction in set(instructions)}\n\n# Sort the 
instructions by frequency\nsorted_instructions = sorted(instruction_counts.items(), key=lambda x: x[1], reverse=True)\n\n# Separate the instructions and their counts for plotting\nsorted_instruction_names = [item[0] for item in sorted_instructions]\nsorted_instruction_counts = [item[1] for item in sorted_instructions]\n\n# Plotting the frequency of instructions\nplt.figure(figsize=(10, 5))\n\nbars = plt.barh(sorted_instruction_names, sorted_instruction_counts, color='skyblue', edgecolor='black', linewidth=1.2)\nplt.title('Instruction Frequency Distribution')\nplt.xlabel('Frequency')\nplt.ylabel('Instruction')\n\n# Show the plot\nplt.tight_layout()\nplt.savefig(\"instruction_frequency_distribution.png\")\nplt.show()","metadata":{"id":"4TAMV5DdnRg7","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"Two plots about the distribution of the lengths of the instances of the input and output sentences","metadata":{"id":"auVHyN0rj-fO"}},{"cell_type":"code","source":"# Calculate the length of each phrase\ninput_lengths = [len(phrase) for phrase in input_phrases]\noutput_lengths = [len(phrase) for phrase in output_phrases]\n\n# Define the bins for the length ranges\nmax_input = max(input_lengths)\nmax_output = max(output_lengths)\n\ninput_bins = [i * max_input / 10 for i in range(1, 11)]\noutput_bins = [i * max_output / 10 for i in range(1, 11)]\nbin_labels_input = [f'{int(input_bins[i-1])}-{int(input_bins[i])}' for i in range(1, 10)]\nbin_labels_output = [f'{int(output_bins[i-1])}-{int(output_bins[i])}' for i in range(1, 10)]\n\n# Bin the lengths into the categories\ninput_binned = np.digitize(input_lengths, input_bins)  # Categorize based on input lengths\noutput_binned = np.digitize(output_lengths, output_bins)  # Categorize based on output lengths\n\n# Count how many phrases fall into each bin\ninput_bin_counts = [sum(input_binned == i) for i in range(1, len(input_bins))]\noutput_bin_counts = [sum(output_binned == i) for i in range(1, len(output_bins))]\n\n# Plotting the bar charts\nplt.figure(figsize=(20, 10))\n\n# Plotting the input phrase lengths\nplt.subplot(1, 2, 1)\nplt.bar(bin_labels_input, input_bin_counts, color='skyblue', edgecolor='black')\nplt.title('Input Phrases Length Distribution')\nplt.xlabel('Length Range')\nplt.ylabel('Number of Phrases')\n\n# Plotting the output phrase lengths\nplt.subplot(1, 2, 2)\nplt.bar(bin_labels_output, output_bin_counts, color='skyblue', edgecolor='black')\nplt.title('Output Phrases Length Distribution')\nplt.xlabel('Length Range')\nplt.ylabel('Number of Phrases')\n\n# Show the plots\nplt.tight_layout()\nplt.savefig(\"sentences_length.png\")\nplt.show()","metadata":{"id":"d9lEMELsiqIx","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Extra Plot\nWordCloud to show the most frequent words in input and output set\n","metadata":{"id":"PJljpW-Jc37a"}},{"cell_type":"code","source":"# Join together all words from input set\ninput = ' '.join(input_phrases)\n\n# Join together all words from output set\noutput = ' '.join(output_phrases)\n\n# Instantiate WC for input and for output\nwordcloud1 = WordCloud(width=800, height=400, background_color='white').generate(input)\nwordcloud2 = WordCloud(width=800, height=400, background_color='white').generate(output)\n\n# Plotting the bar charts\nplt.figure(figsize=(20, 10))\n\n# Plotting the input words\nplt.subplot(1, 2, 1)\nplt.imshow(wordcloud1, interpolation='bilinear')\nplt.axis('off')\nplt.title('Question Word Cloud')\n\n# Plotting the 
output words\nplt.subplot(1, 2, 2)\nplt.imshow(wordcloud2, interpolation='bilinear')\nplt.axis('off')\nplt.title('Answer Word Cloud')\n\n# Show\nplt.tight_layout()\nplt.savefig(\"question_word_cloud.png\")\nplt.show()","metadata":{"id":"M5PVKri8c37a","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Word2Vec","metadata":{"id":"f4wIndWS2lyy"}},{"cell_type":"markdown","source":"Word2Vec transforms the words in the dataset into numerical vectors based on the context in which they appear in the text, creating dense word representations that capture their semantic and syntactic relationships. In this notebook, we are using it to analyze the words in the dataset and generate graphs that visualize the relationships and similarities between words, providing a visual representation of the linguistic connections within the corpus.","metadata":{}},{"cell_type":"code","source":"# Tokenize with spaces or any non word character\ntokenized_sentences = [re.sub('\\W', ' ', sentence).lower().split() for sentence in input_phrases]\n\n# remove sentences that are only 1 word long\ntokenized_sentences = [sentence for sentence in tokenized_sentences if len(sentence) > 1]\n\nfor sentence in tokenized_sentences[:5]:\n    print(sentence)","metadata":{"id":"HUU3TXC5c37a","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Create Word2Vec (ignore words with frequency less than 5)\nWordModel = Word2Vec(tokenized_sentences, vector_size=30, min_count=5, window=10)","metadata":{"id":"e05BkgN-2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"sample = random.sample(list(WordModel.wv.key_to_index), 500)\nword_vectors = WordModel.wv[sample]","metadata":{"id":"blQAYdGr2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### 3D plot with words","metadata":{"id":"CskZCC7a2lyz"}},{"cell_type":"markdown","source":"Now we use [t-SNE](https://www.datacamp.com/tutorial/introduction-t-sne) (t-distributed Stochastic Neighbor Embedding) for data exploration and visualizing high-dimensional data.","metadata":{"id":"OPSytXCQlGy_"}},{"cell_type":"markdown","source":"Visualize a plot with dataset words' in the 3D space","metadata":{"id":"NwSdGIXKlaPm"}},{"cell_type":"code","source":"# Apply t-SNE to word vectors\ntsne = TSNE(n_components=3, n_iter=2000)\ntsne_embedding = tsne.fit_transform(word_vectors)\n\n# Extract individual dimensions\nx, y, z = np.transpose(tsne_embedding)\n\n# Create 3D scatter plot with a subset of the dataset\nfig = px.scatter_3d(x=x[:200],y=y[:200],z=z[:200],text=sample[:200])\n# Full dataset\n# fig = px.scatter_3d(x=x,y=y,z=z,text=sample)\nfig.update_traces(marker=dict(size=3,line=dict(width=2)),textfont_size=10)\nfig.show()","metadata":{"id":"_0EGqbEA2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"Here's a question of the dataset: \"What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?\" now we want to see if Mg2 and Ca2 are close in the 3D space.","metadata":{}},{"cell_type":"code","source":"first_question = ['mg2', 'ca2']\n\nword_vectors = WordModel.wv[first_question+sample]\n\ntsne = TSNE(n_components=3)\ntsne_embedding = tsne.fit_transform(word_vectors)\n\nx, y, z = np.transpose(tsne_embedding)","metadata":{"id":"ruxbKZBt2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"r = (-30,30)\nfig = px.scatter_3d(x=x, y=y, z=z, range_x=r, range_y=r, 
range_z=r, text=first_question + [None] * 500)\nfig.update_traces(marker=dict(size=3,line=dict(width=2)),textfont_size=10)\nfig.show()","metadata":{"id":"Cem4IiRb2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Use Word2Vec to find most similar words\nWordModel.wv.most_similar('ca2')","metadata":{"id":"E795I6ej2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Example of King – Man + Woman = Queen \nvec = WordModel.wv.get_vector('headache') + (WordModel.wv.get_vector('fever') - WordModel.wv.get_vector('drug'))\nWordModel.wv.similar_by_vector(vec)","metadata":{"id":"8RkQLdpv2lyz","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Dataset Class","metadata":{"id":"t__NU-CxtprY"}},{"cell_type":"markdown","source":"We create a `MedDataset` class to ensure consistency between the format of the MedicalMeadow dataset and how it will be used during the training process with roles supported by Llama [text models](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2).","metadata":{}},{"cell_type":"code","source":"class MedDataset(Dataset):\n    def __init__(self, dataset, tokenizer):\n        self.dataset = dataset\n        self.tokenizer = tokenizer\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, idx):\n        example = self.dataset[idx]\n        \n        # Construct a conversation-style prompt with instruction, input, and output\n        messages = [\n            {\"role\": \"system\", \"content\": example['instruction']},  # Starting instruction\n            {\"role\": \"user\", \"content\": example['input']},  # Input provided by the user\n            {\"role\": \"assistant\", \"content\": example['output']}  # Expected output from the assistant\n        ]\n\n        # Create a prompt using the tokenizer's chat template\n        prompt = self.tokenizer.apply_chat_template(\n            messages,\n            tokenize=False,  # The prompt remains untokenized for now\n            add_generation_prompt=True  # Add any required generation-specific tokens\n        )\n\n        # Tokenize the prompt into a format suitable for model input\n        tokens = self.tokenizer(\n            prompt,\n            padding=\"max_length\",  # Pad sequences to a fixed maximum length\n            truncation=True,  # Truncate sequences that exceed the maximum length\n            max_length=128,  # Set the maximum sequence length\n            return_tensors=\"pt\" \n        )\n\n        tokens['labels'] = tokens['input_ids'].clone()\n        # Mask padding tokens in the labels to ignore them during loss computation\n        tokens['labels'][tokens['input_ids'] == self.tokenizer.pad_token_id] = -100\n\n        return {\n            \"input_ids\": tokens['input_ids'].squeeze(),  # Tokenized input sequence\n            \"attention_mask\": tokens['attention_mask'].squeeze(),  # Attention mask for the input\n            \"labels\": tokens['labels'].squeeze()  # Labels for the model to predict\n        }","metadata":{"id":"etwCQ4z6tprY","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Split dataset\ntrain_set, val_set, test_dataset = random_split(ds, [0.8, 0.1, 0.1])","metadata":{"id":"IUEN8DT5YzQR","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Model Creation\nThis section imports and sets the model to be 
fine-tuned","metadata":{"id":"-BMh1vbuc5qp"}},{"cell_type":"markdown","source":"### Load [Llama 3.2 1B Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)\nThe model chosen to be fine-tuned is Llama, specifically version 3.2 which has a variant with only 1B parameters and is optimized for chat-style question answering (Instruct)","metadata":{"id":"17sDebt1YzQR"}},{"cell_type":"code","source":"# Base model from Hugging Face\nbase_model = \"meta-llama/Llama-3.2-1B-Instruct\"\n\n# Load tokenizer\ntokenizer = AutoTokenizer.from_pretrained(base_model)\ntokenizer.pad_token = tokenizer.eos_token","metadata":{"id":"AelKMq1xSsQ-","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"We only initialize the dataset class for train and validation now since we needed the tokenizer, which is useful to apply the template for chat. While for the testset we are going to use a different template.","metadata":{"id":"JhcOiPoxYzQR"}},{"cell_type":"code","source":"# Create datasets for training and validation\ntrain_dataset = MedDataset(train_set, tokenizer)\nval_dataset = MedDataset(val_set, tokenizer)\n\n# Print dataset dimensions\nprint(f\"Train dataset dimension: {len(train_dataset)}\")\nprint(f\"Validation dataset dimension: {len(val_dataset)}\")\nprint(f\"Test dataset dimension: {len(test_dataset)}\")","metadata":{"id":"U775LBiXYzQR","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"In the next cell we configured BitsandBytes, to apply 4-bit quantization.\n It allows us to reduce the precision of the numbers used to represent the model weights, saving memory space and improving the speed of inference","metadata":{"id":"nCE-rr6AYzQR"}},{"cell_type":"code","source":"# Set dtype for quantization\ncompute_dtype = getattr(torch, \"float16\")\n\n# Set quantization config for BitsAndBytes\nquant_config = BitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n    bnb_4bit_compute_dtype=compute_dtype,\n    bnb_4bit_use_double_quant=False,\n    bnb_4bit_representation=\"nested\"\n)","metadata":{"id":"gR4BtF8AHD0I","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"We load the model from Hugging Face","metadata":{"id":"eFr1wdv2YzQR"}},{"cell_type":"code","source":"# Load LLAMA model with quantization\nmodel = AutoModelForCausalLM.from_pretrained(\n    base_model,\n    quantization_config=quant_config,\n    device_map={\"\": 0},\n    torch_dtype=torch.float32,\n    trust_remote_code=True\n)\nmodel.config.use_cache = False\nmodel.config.pretraining_tp = 1\n\nmodel.gradient_checkpointing_enable()\nmodel = prepare_model_for_kbit_training(model)","metadata":{"id":"AAl1RmdGYzQR","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Model training","metadata":{"id":"XMNZAui5YzQS"}},{"cell_type":"markdown","source":"Relative to training, in order to minimize the fine-tuning time and without having to retrain all model parameters, we used the PEFT technique, to update only a small part of parameters useful for the task","metadata":{"id":"CTGJKF_XYzQS"}},{"cell_type":"code","source":"peft_params = LoraConfig(\n    lora_alpha=32,  # Scaling factor for the LoRA updates to control the adaptation strength\n    lora_dropout=0.1, \n    r=16,  # Rank of the low-rank adaptation matrices; smaller values reduce the number of trainable parameters\n    bias=\"none\",\n    task_type=\"CAUSAL_LM\",  # Causal language modeling 
task\n)","metadata":{"id":"ew98HuwaKW1P","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"model.train()\n\n# Define training parameters using the TrainingArguments class\ntraining_params = TrainingArguments(\n    output_dir=\"./results\",  # Directory to save model checkpoints and training outputs\n    num_train_epochs=1,\n    per_device_train_batch_size=8,  # Batch size for each device during training\n    gradient_accumulation_steps=4,  # Number of steps to accumulate gradients before updating weights\n    optim=\"paged_adamw_32bit\",\n    eval_strategy=\"steps\",  # Evaluation frequency is defined by the number of steps\n    logging_steps=90,\n    eval_steps=90,\n    learning_rate=2e-4,  \n    weight_decay=0.001,  # Weight decay factor for regularization to prevent overfitting\n    fp16=False,  # Disable 16-bit floating-point precision for training\n    bf16=False,  # Disable bfloat16 precision for training\n    max_grad_norm=0.3,  # Maximum gradient norm for gradient clipping\n    max_steps=-1,  # Total number of training steps (-1 means determined by epochs)\n    warmup_ratio=0.03,  # Fraction of steps for learning rate warmup\n    group_by_length=True,  # Group samples of similar lengths for more efficient training\n    lr_scheduler_type=\"constant\", \n    gradient_checkpointing=True  # Enable gradient checkpointing to save memory\n)","metadata":{"id":"n8dBHX-j-M1J","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"For fine-tuning we use Supervised Fine-Tuning Trainer, passing the parameters set previously","metadata":{"id":"DzC-_09pYzQS"}},{"cell_type":"code","source":"trainer = SFTTrainer(\n    model=model,\n    train_dataset=train_dataset,\n    eval_dataset=val_dataset,\n    peft_config=peft_params,\n    max_seq_length=256,\n    tokenizer=tokenizer,\n    args=training_params,\n    packing=False,\n)","metadata":{"id":"ARnpWrEFSsQ_","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"Uncomment the below cell to train the model","metadata":{}},{"cell_type":"code","source":"# Train the model\n# trainer.train()\n\n# Save the model and tokenizer\n# trainer.save_model(\"./fine-tuned-model\")\n# tokenizer.save_pretrained(\"./fine-tuned-model\")","metadata":{"id":"ARnpWrEFSsQ_","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Plots","metadata":{"id":"a_6XDhvKc37b"}},{"cell_type":"markdown","source":"The below cell plots the validation and training loss. 
If you execute the previous training, uncomment also this cell","metadata":{}},{"cell_type":"code","source":"# df = pd.DataFrame(trainer.state.log_history)\n\n# # Plot training and validation loss\n# plt.figure(figsize=(10, 6))\n# plt.plot(df['loss'], label='Training Loss')\n# plt.title('Training and Validation Loss')\n# plt.xlabel('Steps')\n# plt.ylabel('Loss')\n# plt.legend()\n# plt.grid(True)\n\n# plt.savefig(\"training_and_validation_loss.png\")\n# plt.show()","metadata":{"id":"5DaDCbaMc37b","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### 0-shot fine tuning","metadata":{}},{"cell_type":"code","source":"# Define the model for 0-shot\nmodel0 = AutoModelForCausalLM.from_pretrained(\n    base_model,\n    quantization_config=quant_config,\n    device_map={\"\": 0},\n    torch_dtype=torch.float16,\n    trust_remote_code=True\n)\nmodel0.config.use_cache = False\nmodel0.config.pretraining_tp = 1","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Few-shot fine tuning","metadata":{}},{"cell_type":"code","source":"# Prepare the dataset for few-shot fine tuning (just 5 instances)\nsubset = Subset(ds, list(range(5)))\n\nfew_dataset = MedDataset(subset, tokenizer)\nfew_dataset","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Define the model\nmodelFEW = AutoModelForCausalLM.from_pretrained(\n    base_model,\n    quantization_config=quant_config,\n    device_map={\"\": 0},\n    torch_dtype=torch.float16,\n    trust_remote_code=True\n)\nmodelFEW.config.use_cache = False\nmodelFEW.config.pretraining_tp = 1\n\ntraining_params = TrainingArguments(\n    output_dir=\"./results\",  # Directory to save model checkpoints and training outputs\n    num_train_epochs=1,\n    per_device_train_batch_size=8,  # Batch size for each device during training\n    gradient_accumulation_steps=4,  # Number of steps to accumulate gradients before updating weights\n    optim=\"paged_adamw_32bit\",\n    logging_steps=90,\n    learning_rate=2e-4,  \n    weight_decay=0.001,  # Weight decay factor for regularization to prevent overfitting\n    fp16=False,  # Disable 16-bit floating-point precision for training\n    bf16=False,  # Disable bfloat16 precision for training\n    max_grad_norm=0.3,  # Maximum gradient norm for gradient clipping\n    max_steps=-1,  # Total number of training steps (-1 means determined by epochs)\n    warmup_ratio=0.03,  # Fraction of steps for learning rate warmup\n    group_by_length=True,  # Group samples of similar lengths for more efficient training\n    lr_scheduler_type=\"constant\", \n    gradient_checkpointing=True,  # Enable gradient checkpointing to save memory\n    report_to=\"none\",\n)\n\ntrainer = SFTTrainer(\n    model=modelFEW,\n    train_dataset=few_dataset,\n    peft_config=peft_params,\n    max_seq_length=256,\n    tokenizer=tokenizer,\n    args=training_params,\n    packing=False,\n)\n# Few shot training\ntrainer.train()","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def inference(question, model_inf, tokenizer_inf):\n    messages = [{\"role\": \"system\", \"content\": instructions[0]},\n        {\"role\": \"user\", \"content\": question}]\n    \n    prompt = tokenizer_inf.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    \n    model_inputs = tokenizer_inf(prompt, return_tensors='pt', padding=True, truncation=True).to(\"cuda\")\n    \n    outputs = 
model_inf.generate(**model_inputs, max_new_tokens=128)\n    \n    response = tokenizer_inf.decode(outputs[0], skip_special_tokens=True)\n\n    return response.partition('assistant')[2][2:]","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"question = test_dataset[0]\nquestion","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"responseFEW = inference(question['input'], modelFEW, tokenizer)\nprint(\"Response FEW shot: \", responseFEW)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"response0 = inference(question['input'], model0, tokenizer)\nprint(\"\\nResponse 0 shot: \", response0)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Model Test","metadata":{}},{"cell_type":"code","source":"# Change with model folder\ntrained_model = \"/kaggle/input/medicalllm/pytorch/default/1/model-chatbot-medical-mew\"\n\nmodel = AutoPeftModelForCausalLM.from_pretrained(\n    trained_model,\n    quantization_config=quant_config,\n    device_map={\"\": 0},\n    torch_dtype=torch.float32,\n    trust_remote_code=True\n)\ntokenizer = AutoTokenizer.from_pretrained(trained_model)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"reponse = inference(question['input'], model, tokenizer)\n\nreponse","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Compare our model with MedAlpaca","metadata":{"id":"uJGGWVExc37f"}},{"cell_type":"markdown","source":"We want to use [MedAlpaca](https://huggingface.co/medalpaca/medalpaca-7b) as benchmark to test the quality of our fine-tuning.\n\n***medalpaca-7b*** is a large language model specifically fine-tuned for medical domain tasks. It is based on LLaMA (Large Language Model Meta AI) and contains 7 billion parameters. The primary goal of this model is to improve question-answering and medical dialogue tasks.","metadata":{"id":"ennfkv3vjRbE"}},{"cell_type":"code","source":"# Function to download MedAlpaca model and tokenizer\ndef setup_medAlpaca():\n    try:\n        # try not quantize model\n        model = AutoModelForCausalLM.from_pretrained(\n            \"medalpaca/medalpaca-7b\",\n            trust_remote_code=True,\n            device_map='auto',\n            torch_dtype=torch.float16\n        )\n\n        tokenizer = AutoTokenizer.from_pretrained(\"medalpaca/medalpaca-7b\")\n\n        return model, tokenizer\n\n    except Exception as e:\n        print(f\"Error: {str(e)}\")\n        print(f\"Python version: {sys.version}\")\n        if torch.cuda.is_available():\n            print(f\"GPU: {torch.cuda.get_device_name()}\")\n        return None, None","metadata":{"id":"6eAja8tHc37f","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Setup\nmodelMED, tokenizerMED = setup_medAlpaca()","metadata":{"id":"OomPFh1yc37f","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"To evaluate our model response we use [BLEU](https://www.nltk.org/_modules/nltk/translate/bleu_score.html) score. BLEU's output is always a number between 0 and 1. 
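For reference, it is computed as the geometric mean of modified $n$-gram precisions scaled by a brevity penalty, $\\mathrm{BLEU} = BP \\cdot \\exp\\left(\\sum_{n=1}^{N} w_n \\log p_n\\right)$; NLTK's `sentence_bleu` defaults to uniform weights $w_n = 0.25$ over 1- to 4-grams. 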
This value indicates how similar the candidate text is to the reference texts, with values closer to 1 representing more similar texts.","metadata":{"id":"JBj_UlApj4Fm"}},{"cell_type":"code","source":"nltk.download('punkt')\n\n# BLEU score between the candidate (generated answer) and the reference answer (taken from the dataset)\ndef calculate_bleu(reference, candidate):\n    try:\n      # sentence_bleu expects a list of tokenized references, so tokenize and wrap the reference\n      score = sentence_bleu([reference.split()], candidate)\n      return score\n    except Exception as e:\n      print(f\"Error during BLEU computation: {e}\")\n      return None","metadata":{"id":"bgPTQB_qc37f","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Take 10 instances from testset\ntest = test_dataset[:10]","metadata":{"id":"GAS31dMvc37f","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"test","metadata":{"id":"EbTmer27xNT5","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"alpaca_bleu = []\nour_bleu = []","metadata":{"id":"_vyh3DvPc37f","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Preprompt also used by our model\ninstruction = \"Answer this question truthfully: \"\n\n# medAlpaca\nprint(\"MedAlpaca\")\n\n# Each question is the 'input' field of a test-set instance\nfor i, question in enumerate(test['input']):\n    print(f\"\\nQuestion: {question}\")\n\n    input_text = instruction + question\n\n    # Tokenize\n    input_ids = tokenizerMED(input_text, return_tensors='pt').to(modelMED.device)[\"input_ids\"]\n\n    # Generate response\n    outputs = modelMED.generate(input_ids, max_new_tokens=128)\n    response = tokenizerMED.decode(outputs[0])\n\n    # Clean the response by removing the echoed question (NOT WORKING)\n    clean_response = re.sub(r'.*?\\? 
', '', response, flags=re.DOTALL)\n\n    print(\"   Response MedAlpaca:\", clean_response)\n\n    # Comput BLEU score between generated answer and answer in testset\n    candidate = clean_response.split()\n    bleu_score = calculate_bleu(test['output'][i], candidate)\n\n    if bleu_score is not None:\n        print(f\"BLEU score for Llama answer {i}: {bleu_score}\")\n        # Append score to a list\n        alpaca_bleu.append(bleu_score)","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"#### Test our model","metadata":{}},{"cell_type":"code","source":"# Same code but for our model\nprint(\"\\nOur Model\")\n\nfor i, question in enumerate(test['input']):\n    print(f\"\\nQuestion: {question}\")\n\n    response = inference(question, model, tokenizer)\n\n    print(\"Response Medical Meadow: \", response)\n\n    # Compute BLEU score between generated answer and answer in testset\n    candidate = response.split()\n    bleu_score = calculate_bleu(test['output'][i], candidate)\n\n    if bleu_score is not None:\n        print(f\"BLEU score for Medical Meadow answer {i}: {bleu_score}\")\n        # Append score to a list\n        our_bleu.append(bleu_score)","metadata":{"id":"sMpKFJWSyCrJ","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Table to view all scores\ntable = list(zip([0,1,2,3,4,5,6,7,8,9], alpaca_bleu, our_bleu))\n\nprint('BLEU score')\nprint(tabulate(table, headers=['Question', 'MedAlpaca', 'MedicalMeadow'], tablefmt='grid'))\n\n# Average BLEU score for both models\nbleu_alpaca = np.mean(alpaca_bleu)\nbleu_med = np.mean(our_bleu)\nprint(f\"\\n-> Average BLEU score for MedAlpaca model: {bleu_alpaca}\")\nprint(f\"-> Average BLEU score for MedicalMeadow model: {bleu_med}\")","metadata":{"id":"72irnssYc37f","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Add voice interactivity","metadata":{"id":"kPKlNEM_sFu0"}},{"cell_type":"markdown","source":"### Record Audio","metadata":{"id":"3QSzEjW2MZrc"}},{"cell_type":"code","source":"\"\"\"\nReferences:\nhttps://blog.addpipe.com/recording-audio-in-the-browser-using-pure-html5-and-minimal-javascript/\nhttps://stackoverflow.com/a/18650249\nhttps://hacks.mozilla.org/2014/06/easy-audio-capture-with-the-mediarecorder-api/\nhttps://air.ghost.io/recording-to-an-audio-file-using-html5-and-js/\nhttps://stackoverflow.com/a/49019356\n\"\"\"\n\nAUDIO_HTML = \"\"\"\n<script>\nvar my_div = document.createElement(\"DIV\");\nvar my_p = document.createElement(\"P\");\nvar my_btn = document.createElement(\"BUTTON\");\nvar t = document.createTextNode(\"Press to start recording\");\n\nmy_btn.appendChild(t);\n//my_p.appendChild(my_btn);\nmy_div.appendChild(my_btn);\ndocument.body.appendChild(my_div);\n\nvar base64data = 0;\nvar reader;\nvar recorder, gumStream;\nvar recordButton = my_btn;\n\nvar handleSuccess = function(stream) {\n  gumStream = stream;\n  var options = {\n    //bitsPerSecond: 8000, //chrome seems to ignore, always 48k\n    mimeType : 'audio/webm;codecs=opus'\n    //mimeType : 'audio/webm;codecs=pcm'\n  };\n  //recorder = new MediaRecorder(stream, options);\n  recorder = new MediaRecorder(stream);\n  recorder.ondataavailable = function(e) {\n    var url = URL.createObjectURL(e.data);\n    var preview = document.createElement('audio');\n    preview.controls = true;\n    preview.src = url;\n    document.body.appendChild(preview);\n\n    reader = new FileReader();\n    reader.readAsDataURL(e.data);\n    reader.onloadend = 
function() {\n      base64data = reader.result;\n      //console.log(\"Inside FileReader:\" + base64data);\n    }\n  };\n  recorder.start();\n  };\n\nrecordButton.innerText = \"Recording... press to stop\";\n\nnavigator.mediaDevices.getUserMedia({audio: true}).then(handleSuccess);\n\n\nfunction toggleRecording() {\n  if (recorder && recorder.state == \"recording\") {\n      recorder.stop();\n      gumStream.getAudioTracks()[0].stop();\n      recordButton.innerText = \"Saving the recording... pls wait!\"\n  }\n}\n\n// https://stackoverflow.com/a/951057\nfunction sleep(ms) {\n  return new Promise(resolve => setTimeout(resolve, ms));\n}\n\nvar data = new Promise(resolve=>{\n//recordButton.addEventListener(\"click\", toggleRecording);\nrecordButton.onclick = ()=>{\ntoggleRecording()\n\nsleep(2000).then(() => {\n  // wait 2000ms for the data to be available...\n  // ideally this should use something like await...\n  //console.log(\"Inside data:\" + base64data)\n  resolve(base64data.toString())\n\n});\n\n}\n});\n\n</script>\n\"\"\"\n\ndef get_audio():\n  display(HTML(AUDIO_HTML))\n  data = eval_js(\"data\")\n  binary = b64decode(data.split(',')[1])\n\n  process = (ffmpeg\n    .input('pipe:0')\n    .output('pipe:1', format='wav')\n    .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)\n  )\n  output, err = process.communicate(input=binary)\n\n  riff_chunk_size = len(output) - 8\n  # Break up the chunk size into four bytes, held in b.\n  q = riff_chunk_size\n  b = []\n  for i in range(4):\n      q, r = divmod(q, 256)\n      b.append(r)\n\n  # Replace bytes 4:8 in proc.stdout with the actual size of the RIFF chunk.\n  riff = output[:4] + bytes(b) + output[8:]\n\n  sr, audio = wav_read(io.BytesIO(riff))\n\n  return audio, sr","metadata":{"id":"7uLKcUMQMeaL","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Register audio\naudio, sr = get_audio()","metadata":{"id":"oo1ink-uMiYo","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Save audio to file\nscipy.io.wavfile.write('./recording.wav', sr, audio)","metadata":{"id":"DaV9sVvZMlCL","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Speech to Text","metadata":{"id":"_1k0fMEMNtl2"}},{"cell_type":"markdown","source":"We analyze the voice recording in the form of spectrograms and perform resempling.\n\nWe tried normalize the audio but it will get a worse result. 
We resample to 16,000 Hz because, for most speech-focused tasks, 16,000 Hz is optimal (8,000 Hz is used for low-bandwidth telephony and 44,100 Hz for music).","metadata":{"id":"1uRvjNElc37g"}},{"cell_type":"code","source":"# Resample audio\ntarget_sample_rate = 16000\naudio, sr = librosa.load(\"recording.wav\", sr=None)  # Load with original sampling rate\nif sr != target_sample_rate:\n    audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)\n\n# Save processed audio\nsf.write(\"processed_audio.wav\", audio, target_sample_rate)\n\nprint(f\"Preprocessed audio saved as processed_audio.wav\")","metadata":{"id":"7zMag0Huc37g","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"whis = whisper.load_model(\"base\")\n\n# Load audio and pad/trim it to fit 30 seconds\naudio = whisper.load_audio(\"./processed_audio.wav\")\naudio = whisper.pad_or_trim(audio)\n\n# Plot the audio\nfig = plt.figure(figsize=(16,4))\nplt.plot(audio, linewidth=0.4)\nplt.ylabel('Amplitude')\nplt.xlabel('Samples')\nplt.show()\n\n# Move log-Mel spectrogram to the same device as the model\nmel = whisper.log_mel_spectrogram(audio).to(whis.device)\n\n# Visualize spectrogram\nfig = plt.figure(figsize=(10,4))\nplt.pcolormesh(mel.cpu().numpy())\nplt.colorbar(label='Power [dB]')\nplt.ylabel('Frequency [Hz]')\nplt.xlabel('Time [10ms]')\nplt.show()\n\n# Use the mel spectrogram to detect the language\n_, probs = whis.detect_language(mel)\nlang = max(probs, key=probs.get)\n\n# Print result\nprint(f\"Detected language: {lang}, confidence: {probs[lang]:.3f}\")\n\n# Decode the audio\noptions = whisper.DecodingOptions(fp16 = False)\nresult = whisper.decode(whis, mel, options)\nprint(result.text)","metadata":{"id":"sPQmD992Lpjh","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Feed the transcribed text to the model\ninputs = tokenizer(result.text, return_tensors=\"pt\")\ninputs = {k: v.to(model.device) for k, v in inputs.items()}\noutput = model.generate(**inputs)\nprint(tokenizer.decode(output[0]))","metadata":{"id":"La6g7PIIN3yo","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Text to Speech","metadata":{"id":"LZLo43qaNyqn"}},{"cell_type":"markdown","source":"We construct the synthetic voice from spectrograms through the use of `tacotron2` and `waveglow`","metadata":{}},{"cell_type":"code","source":"# Load Tacotron model\ntacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')\ntacotron2 = tacotron2.to('cuda')\ntacotron2.eval()","metadata":{"id":"DVEewwceOF-Y","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Load Waveglow\nwaveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp16')\nwaveglow = waveglow.remove_weightnorm(waveglow)\nwaveglow = waveglow.to('cuda')\nwaveglow.eval()","metadata":{"id":"tcKqPWGXOJz9","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Add model response\nprint(tokenizer.decode(output[0]))\ntext = tokenizer.decode(output[0])","metadata":{"id":"DlHbL8_YOMON","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Load tts utils\nutils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')\nsequences, lengths = utils.prepare_input_sequence([text])\nsequences","metadata":{"id":"SYGrvo4YOOBX","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"with 
torch.no_grad():\n    mel, _, _ = tacotron2.infer(sequences, lengths)\n\n%matplotlib inline\n\n# Plot the voice spectrogram\nfig = plt.figure(figsize=(10,4))\nplt.pcolormesh(mel[0].cpu().numpy())\nplt.colorbar(label='Power [dB]')\nplt.ylabel('Frequency [Hz]')\nplt.xlabel('Time [10ms]')\nplt.show()","metadata":{"id":"oOyIqRq9OQJP","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"with torch.no_grad():\n    audio = waveglow.infer(mel)\naudio_numpy = audio[0].data.cpu().numpy()\nrate = 22050\n\n# Plot the amplitude over samples\nfig = plt.figure(figsize=(16,4))\nplt.plot(audio_numpy, linewidth=0.4)\nplt.ylabel('Amplitude')\nplt.xlabel('Samples')\nplt.show()","metadata":{"id":"t-dE1l-dOVwn","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Play response audio\nAudio(audio_numpy, rate=rate)","metadata":{"id":"PNwE-1SuObFF","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Potential Extensions","metadata":{"id":"mgu3JQjGc37i"}},{"cell_type":"markdown","source":"### Improved Audio\n\nWe use a bigger model to recognize the speech, tha basic model is whisper but now we use the larger version.","metadata":{"id":"9C_3YgMPtprd"}},{"cell_type":"code","source":"# Setting the device\ndevice = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\ntorch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n\n# Loading the model\nmodel_id = \"openai/whisper-large-v3\"\n\nmodel_stp = AutoModelForSpeechSeq2Seq.from_pretrained(\n    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True\n)\nmodel_stp.to(device)\nprocessor = AutoProcessor.from_pretrained(model_id)\n\npipe = pipeline(\n    \"automatic-speech-recognition\",\n    model=model_stp,\n    tokenizer=processor.tokenizer,\n    feature_extractor=processor.feature_extractor,\n    torch_dtype=torch_dtype,\n    device=device,\n)\n\n# Inference the model\nresult = pipe(\"recording.wav\")\nprint(result[\"text\"])","metadata":{"id":"eSQcFys7c37i","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Inference the model with our question\nquestion = result[\"text\"]\nprint(question)\n\n# Loading the fine-tuned model and the tokenizer for inference\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"model\",\n        max_seq_length = 2048,\n        dtype = None,\n        load_in_4bit = True,\n    )\n\n# Enable faster inference\nFastLanguageModel.for_inference(model)\n\n# Set the right template for the question\nmessages = [\n    {\"role\": \"user\", \"content\": f\"Answer this question truthfully:{question}\"},\n]\n\n# Standard generation\ninputs = tokenizer.apply_chat_template(\n    messages,\n    tokenize = True,\n    add_generation_prompt = True,\n    return_tensors = \"pt\",\n).to(\"cuda\")\n\nprint(\"\\n Generation:\")\nstreamer = TextStreamer(tokenizer, skip_prompt = True)\n\nresponse = model.generate(\n    input_ids = inputs,\n    streamer = streamer,\n    max_new_tokens = 64,\n    use_cache = True,\n    temperature = 0.7,\n    min_p = 0.1\n)\n\n# Get the response\nresp = tokenizer.decode(response[0], skip_special_tokens=True)\nresponse = resp.partition('assistant')[2][2:]","metadata":{"id":"6H9L3uHQc37i","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"We use a different model for Text-to-speech to achieve a custom voice","metadata":{}},{"cell_type":"code","source":"# Load text to speech model\ntts = 
TTS(\"tts_models/multilingual/multi-dataset/xtts_v2\")\ntts.to(device)\n\n# Generate speech by cloning a voice\ntts.tts_to_file(text=response,\n                file_path=\"output_tts.wav\",\n                speaker_wav=\"obama_audio.mp3\",\n                language=\"en\")","metadata":{"id":"S56zwOs2c37i","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Audio with generated response\nAudio(\"output_tts.wav\")","metadata":{"id":"zFCg_Vsdc37i","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Chat Bot","metadata":{"id":"SPypvRtFo9y3"}},{"cell_type":"markdown","source":"Now we want to use our fine-tuned model as a chatbot which in a continuos way responds to the user's questions.","metadata":{}},{"cell_type":"code","source":"# Initialise dialogue history\ndialogue_history = [\"Hello, I'm a medical assistant chatbot, how can I help you?\\n\"]\n\n# Start chatting\nprint(\"Press [Ctrl-C] to stop\\n\\n\\n\\n\")\nprint(f\"Chatbot: {dialogue_history[0]}\")\n# Keep talking until stop\nrunning = True\nwhile running:\n    try:\n        # Read user message\n        user_message = input(\"User: \")\n        # Append message to dialogue history\n        dialogue_history.append(user_message)\n        # Search for a chatbot response\n        text = inference(user_message, model, tokenizer)\n        \n        # Append chatbot response to dialogue history\n        dialogue_history.append(text)\n        # Print chatbot response\n        print(f\"Chatbot: {text}\\n\")\n    except KeyboardInterrupt:\n        running = False","metadata":{"id":"0pnRhiT6Di2b","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Unsloth","metadata":{"id":"oSvH1RiqHVIv"}},{"cell_type":"markdown","source":"\n[Unsloth](https://unsloth.ai) provides an efficient framework for deploying large-scale language models with minimal computational and memory overhead. By incorporating 4-bit quantization, Unsloth models achieve faster inference and significantly reduce hardware requirements without compromising performance.\n\nThis makes it an excellent choice for:\n- Large-scale text generation tasks\n- Real-time applications\n- Deployments in resource-constrained environments\n\nAdditionally,  offers a diverse set of pre-trained and fine-tuned models for various use cases, ensuring flexibility and adaptability to a wide range of NLP tasks.\n","metadata":{"id":"Cfh79ivuuIs_"}},{"cell_type":"code","source":"# Define maximum sequence length; RoPE Scaling is auto-supported for longer sequences.\nmax_seq_length = 2048\n\n# Set computation data type. 
Auto-detected by default; use Float16 or Bfloat16 for specific GPUs.\ndtype = None\n\n# Enable 4-bit quantization to save memory and boost inference speed.\nload_in_4bit = True\n\n# Pre-quantized 4-bit models for faster downloads and reduced memory usage.\nfourbit_models = [\n    \"unsloth/Meta-Llama-3.1-8B-bnb-4bit\",\n    \"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit\",\n    \"unsloth/Meta-Llama-3.1-70B-bnb-4bit\",\n    \"unsloth/Meta-Llama-3.1-405B-bnb-4bit\",\n    \"unsloth/Mistral-Small-Instruct-2409\",\n    \"unsloth/mistral-7b-instruct-v0.3-bnb-4bit\",\n    \"unsloth/Phi-3.5-mini-instruct\",\n    \"unsloth/Phi-3-medium-4k-instruct\",\n    \"unsloth/gemma-2-9b-bnb-4bit\",\n    \"unsloth/gemma-2-27b-bnb-4bit\",\n    \"unsloth/Llama-3.2-1B-bnb-4bit\",\n    \"unsloth/Llama-3.2-1B-Instruct-bnb-4bit\",\n    \"unsloth/Llama-3.2-3B-bnb-4bit\",\n    \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\",\n    \"unsloth/Llama-3.3-70B-Instruct-bnb-4bit\"\n]\n\n# Load the specified pre-trained model and tokenizer with chosen settings.\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name = \"unsloth/Llama-3.2-3B-Instruct\",\n    dtype = dtype,\n    load_in_4bit = load_in_4bit,\n)\n","metadata":{"id":"TtOVoYxocULS","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"LoRA adapters","metadata":{"id":"vR87HRbfce2i"}},{"cell_type":"code","source":"# Configure the PEFT (Parameter-Efficient Fine-Tuning) model with LoRA (Low-Rank Adaptation).\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r = 16,  # LoRA rank; suggested values: 8, 16, 32, 64, 128\n    target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n                      \"gate_proj\", \"up_proj\", \"down_proj\"],  # Modules to apply LoRA\n    lora_alpha = 16,  # Scaling factor for LoRA\n    lora_dropout = 0,  # Dropout rate; 0 is optimized for most cases\n    bias = \"none\",  # Bias mode; \"none\" is optimized for minimal overhead\n    use_gradient_checkpointing = \"unsloth\",  # Optimized gradient checkpointing for long contexts\n    random_state = 3407,  # Seed for reproducibility\n    use_rslora = False,  # Enable Rank Stabilized LoRA\n    loftq_config = None,  # Config for LoFT-Q quantization\n)","metadata":{"id":"G-0tWbxschYi","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"Format Template","metadata":{"id":"4MovmBY3c3Rm"}},{"cell_type":"code","source":"# Load the dataset\ndataset = load_dataset(\"medalpaca/medical_meadow_medical_flashcards\", split=\"train\")\n\n# Set up the tokenizer with Llama-3.1 chat template\ntokenizer = get_chat_template(\n    tokenizer,\n    chat_template=\"llama-3.1\",\n)","metadata":{"id":"5w8bQdnhc9G3","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def format_examples(examples):\n    \"\"\"Format the examples while maintaining original columns\"\"\"\n    texts = []\n    for instruction, input_text, output in zip(\n        examples['instruction'],\n        examples['input'],\n        examples['output']\n    ):\n        # Apply the template to each example individually\n        formatted_text = tokenizer.apply_chat_template(\n            [\n                {\"role\": \"user\", \"content\": f\"{instruction}: {input_text}\"},\n                {\"role\": \"assistant\", \"content\": output}\n            ],\n            tokenize=False,\n            add_generation_prompt=False\n        )\n        texts.append(formatted_text)\n\n    # Return all original columns plus the formatted 
text\n    return {\n        \"instruction\": examples[\"instruction\"],\n        \"input\": examples[\"input\"],\n        \"output\": examples[\"output\"],\n        \"text\": texts\n    }","metadata":{"id":"5w8bQdnhc9G3","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Apply the formatting while keeping original columns\nformatted_dataset = dataset.map(\n    format_examples,\n    batched=True,\n)","metadata":{"id":"5w8bQdnhc9G3","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Print an example to verify the format\nprint(\"Example of formatted conversation:\")\nprint(formatted_dataset[0]['text'])\n\n# The dataset is now ready for training\n# You can access it as formatted_dataset['text']\ntrain_dataset, val_dataset, test_dataset = random_split(formatted_dataset, [0.8, 0.1, 0.1])\n# If you need to split it into train/validation sets:\ntrain_val = formatted_dataset.train_test_split(test_size=0.1, seed=42)\ntrain_data = train_val['train']\nval_data = train_val['test']\n\nprint(\"\\nDataset sizes:\")\nprint(f\"Train: {len(train_data)}\")\nprint(f\"Validation: {len(val_data)}\")","metadata":{"id":"9acv847bdClb","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"#### Unsloth training","metadata":{"id":"GDmyCvJZHX3U"}},{"cell_type":"markdown","source":"Train the model","metadata":{"id":"5kca4Wt6dGMb"}},{"cell_type":"code","source":"# Data collator preparation for the dataset\ndata_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)\n\n# Trainer with evaluation and without reporting\ntrainer = SFTTrainer(\n    model=model,\n    tokenizer=tokenizer,\n    train_dataset=train_data,  # Training dataset\n    # eval_dataset=val_data,  # Validation dataset\n    dataset_text_field=\"text\",\n    max_seq_length=max_seq_length,\n    data_collator=data_collator,\n    dataset_num_proc=2,\n    packing=False,\n    args=TrainingArguments(\n        per_device_train_batch_size=4,  # Training batch size\n        per_device_eval_batch_size=8,  # Evaluation batch size\n        gradient_accumulation_steps=8,\n        warmup_steps=10,\n        max_steps=30,\n        learning_rate=1e-4,\n        fp16=True,  # Mixed precision\n        logging_steps=2,  # Log every 2 steps\n        # eval_strategy=\"steps\",  # Enable evaluation\n        # eval_steps=2,  # Evaluation frequency\n        save_steps=50,  # Checkpoint saving frequency\n        save_total_limit=2,  # Limit the number of saved checkpoints\n        optim=\"adamw_8bit\",\n        weight_decay=0.01,\n        lr_scheduler_type=\"cosine\",\n        seed=42,\n        output_dir=\"outputs\",\n        report_to=\"none\",  # Disable reporting\n        # load_best_model_at_end=True,  # Load the best model at the end\n        metric_for_best_model=\"eval_loss\",  # Use loss as the metric\n        greater_is_better=False,  # Lower loss is better\n    ),\n)","metadata":{"id":"-OCmnk9fdNuB","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# We use Unsloth's train_on_completions method to focus training on the assistant's responses\ntrainer = train_on_responses_only(\n    trainer,\n    instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",  # Identifier for the user's input\n    response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",  # Identifier for the assistant's output\n)","metadata":{"id":"BgiZ1Fi7dT5H","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# We 
verify that the masking has been correctly applied by decoding the tokenized input\ntokenizer.decode(trainer.train_dataset[5][\"input_ids\"])","metadata":{"id":"yEaUPiqzdVrS","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Get the token ID for a space\nspace_token_id = tokenizer(\" \", add_special_tokens=False).input_ids[0]\n\n# Process the labels, replacing -100 with the space token ID\nlabels = trainer.train_dataset[5][\"labels\"]\nprocessed_labels = [space_token_id if token == -100 else token for token in labels]\n\n# Decode the processed labels into text\ndecoded_text = tokenizer.decode(processed_labels)\n\nprint(decoded_text)","metadata":{"id":"f5kkR3xvdhRw","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# See the stats before training\ngpu_stats = torch.cuda.get_device_properties(0)\nstart_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\nmax_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\nprint(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\nprint(f\"{start_gpu_memory} GB of memory reserved.\")","metadata":{"id":"K_1w3r-ud7pc","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"trainer_stats = trainer.train()","metadata":{"id":"9fB3HtLtd0-W","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"df = pd.DataFrame(trainer.state.log_history)\n\n# Plot training and validation loss\nplt.figure(figsize=(10, 6))\nplt.plot(df['loss'], label='Training Loss')\nplt.title('Training and Validation Loss')\nplt.xlabel('Steps')\nplt.ylabel('Loss')\nplt.legend()\nplt.grid(True)\nplt.savefig('train_loss_unsloth.png')\nplt.show()","metadata":{"id":"T3HYffPXc37h","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"#Show final memory and time stats\nused_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\nused_memory_for_lora = round(used_memory - start_gpu_memory, 3)\nused_percentage = round(used_memory         /max_memory*100, 3)\nlora_percentage = round(used_memory_for_lora/max_memory*100, 3)\nprint(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\nprint(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\nprint(f\"Peak reserved memory = {used_memory} GB.\")\nprint(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\nprint(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\nprint(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")","metadata":{"id":"qmAOlPjQeJRX","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Save the model\nmodel.save_pretrained(\"finetuned_unsloth_model\") # Local saving\ntokenizer.save_pretrained(\"finetuned_unsloth_model\")","metadata":{"id":"oenKzibBlSnr","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"Inference","metadata":{"id":"1Mb8Yzs9eL15"}},{"cell_type":"code","source":"# Setup tokenizer with Llama-3.1 template\n\n# Loading the fine-tuned model and the tokenizer for inference\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n        model_name = \"/content/crimson\",\n        max_seq_length = 2048,\n        dtype = None,\n        load_in_4bit = True,\n    )\n\n\n# Enable faster inference\nFastLanguageModel.for_inference(model)\n\n# Example medical question from our dataset\nmessages = [\n    {\"role\": \"user\", \"content\": \"Answer 
this question truthfully: What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?\"},\n]\n\n# Standard generation\ninputs = tokenizer.apply_chat_template(\n    messages,\n    tokenize = True,\n    add_generation_prompt = True,\n    return_tensors = \"pt\",\n).to(\"cuda\")\n\n\nprint(\"\\n Generation:\")\nstreamer = TextStreamer(tokenizer, skip_prompt = True)\n\n_ = model.generate(\n    input_ids = inputs,\n    streamer = streamer,\n    max_new_tokens = 64,\n    use_cache = True,\n    temperature = 0.7,\n    min_p = 0.1\n)","metadata":{"id":"_20v1bbjeLZw","trusted":true},"outputs":[],"execution_count":null}]}