[133ee8]: / trash / medllama.ipynb

Download this file

1 lines (1 with data), 161.9 kB

{"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"},"accelerator":"GPU","colab":{"gpuType":"T4","provenance":[],"include_colab_link":true},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"sourceId":218074,"sourceType":"modelInstanceVersion","isSourceIdPinned":true,"modelInstanceId":185968,"modelId":208088}],"dockerImageVersionId":30823,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"<a href=\"https://colab.research.google.com/github/sAndreotti/MedicalMeadow/blob/main/ATML_part2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>","metadata":{"id":"view-in-github","colab_type":"text"}},{"cell_type":"markdown","source":"Import Libraries","metadata":{"id":"29nxqgohSsQ3"}},{"cell_type":"code","source":"!pip install datasets accelerate peft bitsandbytes transformers trl==0.12.0 plotly huggingface_hub\n!pip install --upgrade smart_open\n!pip install --upgrade gensim\n!pip install ffmpeg-python\n!pip install -U openai-whisper\n!pip install scipy librosa unidecode inflect","metadata":{"id":"SAho3HGib9-U","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:00:24.088105Z","iopub.execute_input":"2025-01-04T20:00:24.088491Z","iopub.status.idle":"2025-01-04T20:01:13.311397Z","shell.execute_reply.started":"2025-01-04T20:00:24.088447Z","shell.execute_reply":"2025-01-04T20:01:13.310512Z"}},"outputs":[{"name":"stdout","text":"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.2.0)\nRequirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.34.2)\nCollecting peft\n  
Downloading peft-0.14.0-py3-none-any.whl.metadata (13 kB)\nCollecting bitsandbytes\n  Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl.metadata (2.9 kB)\nRequirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\nCollecting trl==0.12.0\n  Downloading trl-0.12.0-py3-none-any.whl.metadata (10 kB)\nRequirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (5.24.1)\nRequirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.24.7)\nRequirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from trl==0.12.0) (13.8.1)\nCollecting transformers\n  Downloading transformers-4.47.1-py3-none-any.whl.metadata (44 kB)\n\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.1/44.1 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\nRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\nRequirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (18.1.0)\nRequirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\nRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\nRequirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\nRequirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\nRequirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\nRequirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\nRequirement already satisfied: 
fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.6.1)\nRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\nRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\nRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\nRequirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.4.1+cu121)\nRequirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.5)\nCollecting huggingface_hub\n  Downloading huggingface_hub-0.27.0-py3-none-any.whl.metadata (13 kB)\nRequirement already satisfied: typing_extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (4.12.2)\nRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\nCollecting tokenizers<0.22,>=0.21 (from transformers)\n  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\nRequirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly) (9.0.0)\nRequirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\nRequirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\nRequirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\nRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\nRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.13.3)\nRequirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.3)\nRequirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.4)\nRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\nRequirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\nRequirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl==0.12.0) (3.0.0)\nRequirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->trl==0.12.0) (2.18.0)\nRequirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from 
markdown-it-py>=2.2.0->rich->trl==0.12.0) (0.1.2)\nRequirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\nRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\nRequirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\nDownloading trl-0.12.0-py3-none-any.whl (310 kB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.2/310.2 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading peft-0.14.0-py3-none-any.whl (374 kB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m374.8/374.8 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl (69.1 MB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m69.1/69.1 MB\u001b[0m \u001b[31m25.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hDownloading transformers-4.47.1-py3-none-any.whl (10.1 MB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.1/10.1 MB\u001b[0m \u001b[31m115.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m0:01\u001b[0m\n\u001b[?25hDownloading huggingface_hub-0.27.0-py3-none-any.whl (450 kB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m450.5/450.5 kB\u001b[0m \u001b[31m23.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hDownloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m84.0 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n\u001b[?25hInstalling collected packages: huggingface_hub, tokenizers, bitsandbytes, transformers, peft, trl\n  Attempting uninstall: huggingface_hub\n    Found existing installation: huggingface-hub 0.24.7\n    Uninstalling huggingface-hub-0.24.7:\n      Successfully uninstalled huggingface-hub-0.24.7\n  Attempting uninstall: tokenizers\n    Found existing installation: tokenizers 0.19.1\n    Uninstalling tokenizers-0.19.1:\n      Successfully uninstalled tokenizers-0.19.1\n  Attempting uninstall: transformers\n    Found existing installation: transformers 4.44.2\n    Uninstalling transformers-4.44.2:\n      Successfully uninstalled transformers-4.44.2\nSuccessfully installed bitsandbytes-0.45.0 huggingface_hub-0.27.0 peft-0.14.0 tokenizers-0.21.0 transformers-4.47.1 trl-0.12.0\nRequirement already satisfied: smart_open in /usr/local/lib/python3.10/dist-packages (7.0.4)\nCollecting smart_open\n  Downloading smart_open-7.1.0-py3-none-any.whl.metadata (24 kB)\nRequirement already satisfied: wrapt in /usr/local/lib/python3.10/dist-packages (from smart_open) (1.16.0)\nDownloading smart_open-7.1.0-py3-none-any.whl (61 kB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.7/61.7 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hInstalling collected packages: smart_open\n  Attempting uninstall: smart_open\n    Found existing installation: smart-open 7.0.4\n    Uninstalling smart-open-7.0.4:\n      Successfully uninstalled smart-open-7.0.4\nSuccessfully installed smart_open-7.1.0\nRequirement already satisfied: gensim in /usr/local/lib/python3.10/dist-packages (4.3.3)\nRequirement already satisfied: numpy<2.0,>=1.18.5 in /usr/local/lib/python3.10/dist-packages (from gensim) (1.26.4)\nRequirement already satisfied: scipy<1.14.0,>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from gensim) (1.13.1)\nRequirement already satisfied: smart-open>=1.8.1 in 
/usr/local/lib/python3.10/dist-packages (from gensim) (7.1.0)\nRequirement already satisfied: wrapt in /usr/local/lib/python3.10/dist-packages (from smart-open>=1.8.1->gensim) (1.16.0)\nCollecting ffmpeg-python\n  Downloading ffmpeg_python-0.2.0-py3-none-any.whl.metadata (1.7 kB)\nRequirement already satisfied: future in /usr/local/lib/python3.10/dist-packages (from ffmpeg-python) (1.0.0)\nDownloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)\nInstalling collected packages: ffmpeg-python\nSuccessfully installed ffmpeg-python-0.2.0\nCollecting openai-whisper\n  Downloading openai-whisper-20240930.tar.gz (800 kB)\n\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m800.5/800.5 kB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n  Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\nRequirement already satisfied: numba in /usr/local/lib/python3.10/dist-packages (from openai-whisper) (0.60.0)\nRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from openai-whisper) (1.26.4)\nRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from openai-whisper) (2.4.1+cu121)\nRequirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from openai-whisper) (4.66.5)\nRequirement already satisfied: more-itertools in /usr/local/lib/python3.10/dist-packages (from openai-whisper) (10.5.0)\nRequirement already satisfied: tiktoken in /usr/local/lib/python3.10/dist-packages (from openai-whisper) (0.8.0)\nCollecting triton>=2.0.0 (from openai-whisper)\n  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)\nRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from triton>=2.0.0->openai-whisper) (3.16.1)\nRequirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba->openai-whisper) (0.43.0)\nRequirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken->openai-whisper) (2024.9.11)\nRequirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.10/dist-packages (from tiktoken->openai-whisper) (2.32.3)\nRequirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->openai-whisper) (4.12.2)\nRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->openai-whisper) (1.13.3)\nRequirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->openai-whisper) (3.3)\nRequirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->openai-whisper) (3.1.4)\nRequirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from 
torch->openai-whisper) (2024.6.1)\nRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken->openai-whisper) (3.3.2)\nRequirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken->openai-whisper) (3.10)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken->openai-whisper) (2.2.3)\nRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken->openai-whisper) (2024.8.30)\nRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->openai-whisper) (2.1.5)\nRequirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->openai-whisper) (1.3.0)\nDownloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.5/209.5 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hBuilding wheels for collected packages: openai-whisper\n  Building wheel for openai-whisper (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n  Created wheel for openai-whisper: filename=openai_whisper-20240930-py3-none-any.whl size=803320 sha256=c9bf255fffc128af536984c1c89aabfca237ca35aeb7150a8d4d50b82ce4f360\n  Stored in directory: /root/.cache/pip/wheels/dd/4a/1f/d1c4bf3b9133c8168fe617ed979cab7b14fe381d059ffb9d83\nSuccessfully built openai-whisper\nInstalling collected packages: triton, openai-whisper\nSuccessfully installed openai-whisper-20240930 triton-3.1.0\nRequirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.13.1)\nRequirement already satisfied: librosa in /usr/local/lib/python3.10/dist-packages (0.10.2.post1)\nCollecting unidecode\n  Downloading Unidecode-1.3.8-py3-none-any.whl.metadata (13 kB)\nRequirement already satisfied: inflect in /usr/local/lib/python3.10/dist-packages (7.4.0)\nRequirement already satisfied: numpy<2.3,>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from scipy) (1.26.4)\nRequirement already satisfied: audioread>=2.1.9 in /usr/local/lib/python3.10/dist-packages (from librosa) (3.0.1)\nRequirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.10/dist-packages (from librosa) (1.2.2)\nRequirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.10/dist-packages (from librosa) (1.4.2)\nRequirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from librosa) (4.4.2)\nRequirement already satisfied: numba>=0.51.0 in /usr/local/lib/python3.10/dist-packages (from librosa) (0.60.0)\nRequirement already satisfied: soundfile>=0.12.1 in /usr/local/lib/python3.10/dist-packages (from librosa) (0.12.1)\nRequirement already satisfied: pooch>=1.1 in /usr/local/lib/python3.10/dist-packages (from librosa) (1.8.2)\nRequirement already satisfied: soxr>=0.3.2 in /usr/local/lib/python3.10/dist-packages (from librosa) (0.5.0.post1)\nRequirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.10/dist-packages (from librosa) (4.12.2)\nRequirement already 
satisfied: lazy-loader>=0.1 in /usr/local/lib/python3.10/dist-packages (from librosa) (0.4)\nRequirement already satisfied: msgpack>=1.0 in /usr/local/lib/python3.10/dist-packages (from librosa) (1.0.8)\nRequirement already satisfied: more-itertools>=8.5.0 in /usr/local/lib/python3.10/dist-packages (from inflect) (10.5.0)\nRequirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from inflect) (4.3.0)\nRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from lazy-loader>=0.1->librosa) (24.1)\nRequirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba>=0.51.0->librosa) (0.43.0)\nRequirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.10/dist-packages (from pooch>=1.1->librosa) (4.3.6)\nRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from pooch>=1.1->librosa) (2.32.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.20.0->librosa) (3.5.0)\nRequirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile>=0.12.1->librosa) (1.17.1)\nRequirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile>=0.12.1->librosa) (2.22)\nRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch>=1.1->librosa) (3.3.2)\nRequirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch>=1.1->librosa) (3.10)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch>=1.1->librosa) (2.2.3)\nRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->pooch>=1.1->librosa) (2024.8.30)\nDownloading 
Unidecode-1.3.8-py3-none-any.whl (235 kB)\n\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m235.5/235.5 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n\u001b[?25hInstalling collected packages: unidecode\nSuccessfully installed unidecode-1.3.8\n","output_type":"stream"}],"execution_count":1},{"cell_type":"code","source":"from datasets import load_dataset\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom collections import Counter\nfrom trl import SFTTrainer\nimport re\nfrom gensim.models.word2vec import Word2Vec\nimport plotly.express as px\nimport random\nfrom sklearn.manifold import TSNE\nfrom torch.utils.data import Dataset\nfrom torch.utils.data import random_split\nfrom peft import prepare_model_for_kbit_training, LoraConfig\nfrom huggingface_hub import login\nimport torch\nfrom transformers import (\n    AutoTokenizer,\n    AutoModelForCausalLM,\n    BitsAndBytesConfig,\n    TrainingArguments\n)\nfrom datasets import Dataset as HFDataset\nfrom peft import AutoPeftModelForCausalLM\nimport whisper\nfrom IPython.display import Audio","metadata":{"id":"q_JimYqjjY4S","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:01:13.312626Z","iopub.execute_input":"2025-01-04T20:01:13.312898Z","iopub.status.idle":"2025-01-04T20:01:50.935335Z","shell.execute_reply.started":"2025-01-04T20:01:13.312877Z","shell.execute_reply":"2025-01-04T20:01:50.934364Z"}},"outputs":[],"execution_count":2},{"cell_type":"markdown","source":"Import Libraries for audio part","metadata":{"id":"E2DDMWKGSsQ5"}},{"cell_type":"code","source":"from IPython.display import HTML, Audio\nfrom google.colab.output import eval_js\nfrom base64 import b64decode\nfrom scipy.io.wavfile import read as wav_read\nimport io\nimport ffmpeg\nimport 
# %% Load the dataset
# Only the "train" split is used by the rest of the notebook, so index into it directly.
ds = load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]
ds

# %% Inspect the columns
# Print the schema, then the number of rows and the first value of each text column.
print(ds.features, "\n")
for heading, column in (
    ("Instruction", "instruction"),
    ("Input", "input"),
    ("Output", "output"),
):
    values = ds[column]
    print(f"{heading}:")
    print(f"length: {len(values)}")
    print(f"example: {values[0]} \n")
# %% Column views shared by the plotting cells
instructions = ds['instruction']
input_phrases = ds['input']
output_phrases = ds['output']

# %% Instruction frequency
# Tally every instruction in a single pass. The previous version called
# list.count() once per distinct instruction, rescanning the whole column each time.
instruction_counts = Counter(instructions)

# most_common() yields (instruction, count) pairs ordered from most to least frequent.
sorted_instructions = instruction_counts.most_common()

# Separate the instructions and their counts for plotting
sorted_instruction_names = [name for name, _ in sorted_instructions]
sorted_instruction_counts = [count for _, count in sorted_instructions]

# Plotting the frequency of instructions
plt.figure(figsize=(10, 5))

bars = plt.barh(sorted_instruction_names, sorted_instruction_counts, color='skyblue', edgecolor='black', linewidth=1.2)
plt.title('Instruction Frequency Distribution')
plt.xlabel('Frequency')
plt.ylabel('Instruction')

# Show the plot
plt.tight_layout()
plt.show()
# %% Phrase length distributions
NUM_LENGTH_BINS = 10


def bin_phrase_lengths(lengths, num_bins=NUM_LENGTH_BINS):
    """Bucket phrase lengths into equal-width bins covering 0..max(lengths).

    Args:
        lengths: iterable of non-negative integer lengths.
        num_bins: number of equal-width bins to produce.

    Returns:
        (edges, labels, counts): the num_bins + 1 bin edges, one "lo-hi" label
        per bin, and the number of lengths falling in each bin.
    """
    longest = max(lengths, default=0)
    # Guard against a zero-width range (empty input or only empty phrases):
    # np.histogram needs strictly increasing edges.
    edges = np.linspace(0, max(longest, 1), num_bins + 1)
    # np.histogram closes the last bin on the right, so the longest phrase is
    # counted, and the first bin starts at 0, so short phrases are counted too.
    # The earlier np.digitize version started its edges at max/10 and silently
    # dropped both groups from the plot.
    counts, _ = np.histogram(lengths, bins=edges)
    labels = [f'{int(lo)}-{int(hi)}' for lo, hi in zip(edges[:-1], edges[1:])]
    return edges, labels, counts


# Calculate the length (in characters) of each phrase
input_lengths = [len(phrase) for phrase in input_phrases]
output_lengths = [len(phrase) for phrase in output_phrases]

input_bins, bin_labels_input, input_bin_counts = bin_phrase_lengths(input_lengths)
output_bins, bin_labels_output, output_bin_counts = bin_phrase_lengths(output_lengths)

# Every phrase must land in exactly one bin.
assert input_bin_counts.sum() == len(input_lengths)
assert output_bin_counts.sum() == len(output_lengths)

# Plotting the bar charts side by side
fig, (ax_input, ax_output) = plt.subplots(1, 2, figsize=(20, 10))

panels = (
    (ax_input, bin_labels_input, input_bin_counts, 'Input Phrases Length Distribution'),
    (ax_output, bin_labels_output, output_bin_counts, 'Output Phrases Length Distribution'),
)
for ax, labels, counts, title in panels:
    ax.bar(labels, counts, color='skyblue', edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel('Length Range')
    ax.set_ylabel('Number of Phrases')

# Show the plots
fig.tight_layout()
plt.show()
class MedDataset(Dataset):
    """Torch dataset that renders each flashcard as a chat and tokenizes it for causal-LM fine-tuning.

    Every row of `dataset` must provide the string fields 'instruction',
    'input' and 'output'; they become the system, user and assistant turns.

    Args:
        dataset: indexable collection of rows (supports len() and [idx]).
        tokenizer: tokenizer exposing apply_chat_template() and __call__().
        max_length: fixed sequence length every example is padded/truncated to.
    """

    def __init__(self, dataset, tokenizer, max_length=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the 1-D `input_ids`, `attention_mask` and `labels` tensors for row `idx`."""
        example = self.dataset[idx]
        messages = [
            {"role": "system", "content": example['instruction']},
            {"role": "user", "content": example['input']},
            {"role": "assistant", "content": example['output']}
        ]

        # The conversation already ends with the assistant's answer, so no
        # generation prompt is appended: with add_generation_prompt=True the
        # template adds an empty assistant header after the answer and the
        # model is trained to emit it.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

        # add_special_tokens=False: chat templates are expected to emit their
        # own special tokens, so tokenizing the rendered text with the default
        # setting would add them a second time (see the Transformers
        # chat-template documentation).
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
            return_tensors="pt"
        )

        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)

        # Exclude padding from the loss using the attention mask rather than the
        # pad token id: this notebook sets pad_token = eos_token, so matching on
        # the id would also mask the genuine end-of-sequence token and the model
        # would never be trained to stop.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
# %% Authenticate and load the tokenizer
import os

# SECURITY: a Hugging Face access token was previously hard-coded in this cell.
# That token is exposed to anyone who can read the notebook and must be revoked
# in the Hugging Face account settings. The token is now read from the HF_TOKEN
# environment variable; when it is unset, login() falls back to its
# interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

base_model = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(base_model)
# Reuse the end-of-sequence token for padding (pad id == eos id from here on).
tokenizer.pad_token = tokenizer.eos_token

# tokenizer.padding_side = "right"

# %% Wrap the raw dataset so each row is tokenized lazily on access
tokenized_dataset = MedDataset(ds, tokenizer)
tokenizer)","metadata":{"id":"C6gCSyFM7IsG","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:01:57.118459Z","iopub.execute_input":"2025-01-04T20:01:57.118727Z","iopub.status.idle":"2025-01-04T20:01:57.122233Z","shell.execute_reply.started":"2025-01-04T20:01:57.118700Z","shell.execute_reply":"2025-01-04T20:01:57.121569Z"}},"outputs":[],"execution_count":12},{"cell_type":"code","source":"# Seed the split so the train/validation/test partitions are the same on every run;\n# an unseeded random_split reshuffles the held-out sets after each kernel restart.\nSPLIT_SEED = 42\nsplit_generator = torch.Generator().manual_seed(SPLIT_SEED)\n\ntrain_dataset, val_dataset, test_dataset = random_split(\n    tokenized_dataset, [0.8, 0.1, 0.1], generator=split_generator\n)\nprint(f\"Train dataset dimension: {len(train_dataset)}\")\nprint(f\"Validation dataset dimension: {len(val_dataset)}\")\nprint(f\"Test dataset dimension: {len(test_dataset)}\")","metadata":{"id":"RRneCjBj9Lag","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:01:57.123113Z","iopub.execute_input":"2025-01-04T20:01:57.123385Z","iopub.status.idle":"2025-01-04T20:01:57.161765Z","shell.execute_reply.started":"2025-01-04T20:01:57.123365Z","shell.execute_reply":"2025-01-04T20:01:57.161155Z"}},"outputs":[{"name":"stdout","text":"Train dataset dimension: 27165\nValidation dataset dimension: 3395\nTest dataset dimension: 3395\n","output_type":"stream"}],"execution_count":13},{"cell_type":"code","source":"# 4-bit NF4 quantization with float16 compute.\ncompute_dtype = torch.float16\n\n# `bnb_4bit_representation` is not a BitsAndBytesConfig argument (transformers reports\n# it as an unused kwarg), so it is no longer passed.\nquant_config = BitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n    bnb_4bit_compute_dtype=compute_dtype,\n    bnb_4bit_use_double_quant=False,\n)\n\nmodel = AutoModelForCausalLM.from_pretrained(\n    base_model,\n    quantization_config=quant_config,\n    device_map={\"\": 0},  # place the whole model on GPU 0\n    torch_dtype=torch.float32,\n    trust_remote_code=True\n)\n# The generation cache is switched off for training.\nmodel.config.use_cache = False\nmodel.config.pretraining_tp = 1\n\nmodel.gradient_checkpointing_enable()\nmodel = 
prepare_model_for_kbit_training(model)","metadata":{"id":"gR4BtF8AHD0I","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:01:57.162479Z","iopub.execute_input":"2025-01-04T20:01:57.162649Z","iopub.status.idle":"2025-01-04T20:03:01.840482Z","shell.execute_reply.started":"2025-01-04T20:01:57.162634Z","shell.execute_reply":"2025-01-04T20:03:01.839883Z"}},"outputs":[{"name":"stderr","text":"Unused kwargs: ['bnb_4bit_representation']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"6cec6b8d4209431d8036028764183634"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9d13c5397e6442be92af2c72c1b48ae0"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"7ccdec106f714145acd8115af2d1ea20"}},"metadata":{}}],"execution_count":14},{"cell_type":"code","source":"# def convert_to_hf_dataset(med_dataset):\n#     # Create lists to store all formatted text\n#     formatted_texts = []\n\n#     # Iterate through all items in the original dataset\n#     for idx in range(len(med_dataset.instruction)):\n#         # Get the formatted text directly using the dataset's __getitem__\n#         formatted_text = med_dataset[idx]\n#         formatted_texts.append(formatted_text)\n\n#     # Create a dictionary with the required format\n#     dataset_dict = {\n#         'text': formatted_texts\n#     }\n\n#     # Convert to HuggingFace Dataset\n#     hf_dataset = 
HFDataset.from_dict(dataset_dict)\n\n#     return hf_dataset\n\n# hf_dataset = convert_to_hf_dataset(garnachoDataset)","metadata":{"id":"gzL6yXsGSsQ_","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:01.841346Z","iopub.execute_input":"2025-01-04T20:03:01.841655Z","iopub.status.idle":"2025-01-04T20:03:01.844971Z","shell.execute_reply.started":"2025-01-04T20:03:01.841625Z","shell.execute_reply":"2025-01-04T20:03:01.844214Z"}},"outputs":[],"execution_count":15},{"cell_type":"code","source":"# LoRA adapter settings: rank 16, alpha 32, 10% dropout, no bias terms.\npeft_params = LoraConfig(\n    lora_alpha=32,\n    lora_dropout=0.1,\n    r=16,\n    bias=\"none\",\n    task_type=\"CAUSAL_LM\",\n)","metadata":{"id":"ew98HuwaKW1P","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:01.845643Z","iopub.execute_input":"2025-01-04T20:03:01.845854Z","iopub.status.idle":"2025-01-04T20:03:01.862789Z","shell.execute_reply.started":"2025-01-04T20:03:01.845824Z","shell.execute_reply":"2025-01-04T20:03:01.862086Z"}},"outputs":[],"execution_count":16},{"cell_type":"code","source":"model.train()\ntraining_params = TrainingArguments(\n    output_dir=\"./results\",\n    num_train_epochs=1,\n    per_device_train_batch_size=8,\n    gradient_accumulation_steps=4,  # 8 x 4 = 32 examples per optimizer step and device\n    optim=\"paged_adamw_32bit\",\n    eval_strategy=\"steps\",\n    logging_steps=90,\n    eval_steps=90,\n    learning_rate=2e-4,\n    weight_decay=0.001,\n    fp16=False,\n    bf16=False,\n    max_grad_norm=0.3,\n    max_steps=-1,  # -1: the run length is governed by num_train_epochs\n    warmup_ratio=0.03,\n    group_by_length=True,\n    # The plain \"constant\" schedule ignores warmup_ratio; this variant warms up first\n    # and then keeps the learning rate constant.\n    lr_scheduler_type=\"constant_with_warmup\",\n    report_to=\"tensorboard\",\n    
gradient_checkpointing=True\n)","metadata":{"id":"n8dBHX-j-M1J","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:01.863493Z","iopub.execute_input":"2025-01-04T20:03:01.863774Z","iopub.status.idle":"2025-01-04T20:03:01.908617Z","shell.execute_reply.started":"2025-01-04T20:03:01.863754Z","shell.execute_reply":"2025-01-04T20:03:01.907842Z"}},"outputs":[],"execution_count":17},{"cell_type":"code","source":"# trainer = SFTTrainer(\n#     model=model,\n#     train_dataset=train_dataset,\n#     eval_dataset=val_dataset,\n#     peft_config=peft_params,\n#     max_seq_length=256,\n#     tokenizer=tokenizer,\n#     args=training_params,\n#     packing=False,\n# )\n\n# # Train the model\n# trainer.train()\n\n# # Save the model and tokenizer\n# trainer.save_model(\"./fine-tuned-model\")\n# tokenizer.save_pretrained(\"./fine-tuned-model\")","metadata":{"trusted":true,"id":"ARnpWrEFSsQ_","execution":{"iopub.status.busy":"2025-01-04T20:03:01.909329Z","iopub.execute_input":"2025-01-04T20:03:01.909543Z","iopub.status.idle":"2025-01-04T20:03:01.912762Z","shell.execute_reply.started":"2025-01-04T20:03:01.909523Z","shell.execute_reply":"2025-01-04T20:03:01.911947Z"}},"outputs":[],"execution_count":18},{"cell_type":"code","source":"# trainer.model.save_pretrained(\"model-chatbot-medical-mew\")\n# trainer.tokenizer.save_pretrained(\"model-chatbot-medical-mew\")","metadata":{"id":"awbFBy8v5rDM","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:01.913577Z","iopub.execute_input":"2025-01-04T20:03:01.913868Z","iopub.status.idle":"2025-01-04T20:03:01.931446Z","shell.execute_reply.started":"2025-01-04T20:03:01.913840Z","shell.execute_reply":"2025-01-04T20:03:01.930682Z"}},"outputs":[],"execution_count":19},{"cell_type":"markdown","source":"## Test Trained Model","metadata":{"id":"MmC39wTYSsQ_"}},{"cell_type":"markdown","source":"### Load pre-trained model","metadata":{"id":"UwQNV8XLYoLX"}},{"cell_type":"code","source":"trained_model = 
\"/kaggle/input/medicalllm/pytorch/default/1/model-chatbot-medical-mew\"\nquestion = \"What does low Mobility suggest?\"\n\n# Load the fine-tuned adapter with the same 4-bit quantization settings as the base model.\nmodel = AutoPeftModelForCausalLM.from_pretrained(\n    trained_model,  # point this at the folder that holds the adapter files\n    quantization_config=quant_config,\n    device_map={\"\": 0},\n    torch_dtype=torch.float32,\n    trust_remote_code=True\n)\ntokenizer = AutoTokenizer.from_pretrained(trained_model)","metadata":{"id":"A-XhCQC1YhLP","trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:01.932192Z","iopub.execute_input":"2025-01-04T20:03:01.932400Z","iopub.status.idle":"2025-01-04T20:03:08.053881Z","shell.execute_reply.started":"2025-01-04T20:03:01.932383Z","shell.execute_reply":"2025-01-04T20:03:08.053186Z"}},"outputs":[],"execution_count":20},{"cell_type":"code","source":"messages = [{\"role\": \"system\", \"content\": instructions[0]},\n    {\"role\": \"user\", \"content\": question}]\n\n# The generation prompt opens the assistant turn that the model has to complete.\nprompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n\nmodel_inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to(\"cuda\")\n\noutputs = model.generate(**model_inputs, max_new_tokens=128)\n\n# generate() returns the prompt followed by the continuation, so only the tokens after\n# the prompt are decoded. Splitting the decoded text on the word \"assistant\" would cut\n# off any answer that itself contains that word.\nprompt_length = model_inputs[\"input_ids\"].shape[1]\ntext = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)\n\nprint(f\"Question: {question}\")\nprint(text)","metadata":{"trusted":true,"id":"NL-rYUPJSsRD","execution":{"iopub.status.busy":"2025-01-04T20:03:08.054650Z","iopub.execute_input":"2025-01-04T20:03:08.054906Z","iopub.status.idle":"2025-01-04T20:03:11.056335Z","shell.execute_reply.started":"2025-01-04T20:03:08.054885Z","shell.execute_reply":"2025-01-04T20:03:11.055341Z"}},"outputs":[{"name":"stderr","text":"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n","output_type":"stream"},{"name":"stdout","text":"Question: What does low Mobility suggest?\n\n\nLow Mobility is a medical condition that is characterized by a decrease in the range of motion of the joints, making it difficult to 
move the affected joints. This can be due to a variety of factors, including injury, degenerative changes, or other medical conditions such as arthritis.\n","output_type":"stream"}],"execution_count":21},{"cell_type":"markdown","source":"### Compare with MedLlama","metadata":{}},{"cell_type":"code","source":"# tokenizerMED = AutoTokenizer.from_pretrained(\"OpenScienceReseach/Med-LLaMA-7b\")\n# modelMED = AutoModelForCausalLM.from_pretrained(\"OpenScienceReseach/Med-LLaMA-7b\")  # works, but slow\n\nMED_MODEL_ID = \"medalpaca/medalpaca-7b\"\n\ntokenizerMED = AutoTokenizer.from_pretrained(MED_MODEL_ID)\n# Load the 7B baseline 4-bit quantized and directly on GPU 0, with the same quant_config\n# as the fine-tuned model. Loaded with the defaults, the model stayed on the CPU and the\n# later move to the GPU by the text-generation pipeline ran out of CUDA memory.\nmodelMED = AutoModelForCausalLM.from_pretrained(\n    MED_MODEL_ID,\n    quantization_config=quant_config,\n    device_map={\"\": 0},\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:36.977452Z","iopub.execute_input":"2025-01-04T20:03:36.977757Z","iopub.status.idle":"2025-01-04T20:06:32.214695Z","shell.execute_reply.started":"2025-01-04T20:03:36.977734Z","shell.execute_reply":"2025-01-04T20:06:32.213870Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json:   0%|          | 0.00/260 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"69dd602f6d4742e3ae56b677e2f83d8f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"2b82cdfbb59641959f9a4becd177b5cf"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"de248b7ac03c471889376937cd231570"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"special_tokens_map.json:   0%|          | 0.00/96.0 [00:00<?, 
?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"983002136f7043629d00997ca858c588"}},"metadata":{}},{"name":"stderr","text":"You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message\nYou are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"config.json:   0%|          | 0.00/542 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"0325bad4d3bc47c392f0f488e5204767"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model.safetensors.index.json:   0%|          | 0.00/28.1k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"3389942be52a4b8ebb23347bb9ba8465"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"3fc80bcc62d741b2bbaca7bb2312fdfe"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00001-of-00003.safetensors:   0%|          | 0.00/9.88G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"de5b6b1da959422c8bf04ec7ca6178d6"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00002-of-00003.safetensors:   0%|          | 0.00/9.89G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9376d933ee1d40cf93564d3c789dc853"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"model-00003-of-00003.safetensors:   0%|          | 0.00/7.18G [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5dedd23938564e4a84c24cc4bdff83d0"}},"metadata":{}},{"name":"stderr","text":"/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:606: 
UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pad_token_id` explicitly as `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation\n  warnings.warn(\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"63bf81e806bc4737953d741452affecb"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"d16c0053c8fc45a4871d4100311b8866"}},"metadata":{}}],"execution_count":25},{"cell_type":"code","source":"!pip install nltk\nimport nltk\nnltk.download('punkt')\n\nfrom nltk.translate.bleu_score import sentence_bleu\n\ndef calculate_bleu(reference, candidate):\n    \"\"\"\n    Compute the sentence-level BLEU score of a generated answer against a reference answer.\n\n    Args:\n        reference: Reference answer as a string, as a list of tokens, or as a list of\n            tokenized references (a list of token lists).\n        candidate: Generated answer as a string or as a list of tokens.\n\n    Returns:\n        The BLEU score, or None when the computation raises.\n    \"\"\"\n    def _tokens(text):\n        # Whitespace-tokenize raw strings; token sequences pass through unchanged.\n        return text.split() if isinstance(text, str) else list(text)\n\n    # sentence_bleu expects a list of tokenized references. A bare string would be\n    # iterated character by character and yield a meaningless score.\n    if isinstance(reference, str):\n        references = [_tokens(reference)]\n    elif reference and isinstance(reference[0], str):\n        references = [list(reference)]\n    else:\n        references = [_tokens(ref) for ref in reference]\n\n    try:\n        return sentence_bleu(references, _tokens(candidate))\n    except Exception as e:\n        # Best effort: callers skip questions whose score is None.\n        print(f\"Error during BLUE computing: {e}\")\n        return None","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-01-04T20:03:11.063599Z","iopub.execute_input":"2025-01-04T20:03:11.063875Z","iopub.status.idle":"2025-01-04T20:03:14.761490Z","shell.execute_reply.started":"2025-01-04T20:03:11.063847Z","shell.execute_reply":"2025-01-04T20:03:14.760443Z"}},"outputs":[{"name":"stderr","text":"/usr/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n  pid, fd = os.forkpty()\n","output_type":"stream"},{"name":"stdout","text":"Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (3.2.4)\nRequirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from nltk) (1.16.0)\n[nltk_data] Downloading package punkt to /usr/share/nltk_data...\n[nltk_data]   Package punkt is already up-to-date!\n","output_type":"stream"}],"execution_count":23},{"cell_type":"code","source":"import time\n\nimport numpy as np\nfrom transformers import pipeline\n\nquestions = [\"What does low REM sleep latency and experiencing hallucinations/sleep paralysis suggest?\",\n             \"What are some possible causes of low PTH and high calcium levels?\",\n             \"What is the term used to describe a condition of low sodium levels and very high proteins or lipids?\"]\n\ndataset_answer = [\"Low REM sleep latency and experiencing hallucinations/sleep paralysis suggests narcolepsy.\",\n                  \"PTH-independent hypercalcemia, which can be caused by cancer, granulomatous disease, or vitamin D intoxication.\",\n                  \"The term used to describe a condition of low sodium levels and very high proteins or lipids is pseudohyponatremia.\"]\n\n# Same generation budget for both models so the comparison is like for like.\nMAX_NEW_TOKENS = 128\n\nllama_bleu = []\nour_bleu = []\n\n# Build the baseline pipeline once; re-creating it for every question repeats the setup\n# and would be counted in the timing below.\npl = pipeline(\"text-generation\", model=modelMED, tokenizer=tokenizerMED)\n\nfor i, question in enumerate(questions):\n    print(f\"Question: {question}\")\n    start = time.time()\n    # llama\n    # The pipeline returns a list with one dict per generated sequence;\n    # return_full_text=False keeps the echoed question out of the scored answer.\n    response = pl(question, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False)[0][\"generated_text\"]\n    end = time.time() - start\n    print(\"time for generate response: \", end)\n    print(\"Response MedLlama:\", response)\n\n    candidate = response.split()\n    # BLEU needs a list of tokenized references, not the raw answer string.\n    bleu_score = calculate_bleu([dataset_answer[i].split()], candidate)\n\n    if bleu_score is not None:\n        print(f\"BLEU score for Llama answer {i}: {bleu_score}\")\n        llama_bleu.append(bleu_score)\n\n\nfor i, question in enumerate(questions):\n    # our\n    messages = [{\"role\": \"system\", \"content\": instructions[0]},\n    {\"role\": \"user\", \"content\": question}]\n    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n    model_inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to(\"cuda\")\n    outputs = model.generate(**model_inputs, max_new_tokens=MAX_NEW_TOKENS)\n    # Decode only the tokens generated after the prompt; splitting the decoded text on\n    # the word \"assistant\" would cut off answers that contain that word.\n    prompt_length = model_inputs[\"input_ids\"].shape[1]\n    our = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)\n    print(\"Our Response: \", our)\n\n    candidate = our.split()\n    bleu_score = calculate_bleu([dataset_answer[i].split()], candidate)\n\n    if bleu_score is not None:\n        print(f\"BLEU score for our answer {i}: {bleu_score}\")\n        our_bleu.append(bleu_score)\n\n# The labels say \"Average\", so report the mean of the per-question scores.\nprint(f\"Average BLEU score for our model: {np.mean(our_bleu)}\")\nprint(f\"Average BLEU score for Llama model: {np.mean(llama_bleu)}\")","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}