{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "16dc0e41-80bd-4453-b421-dcf315741bf4",
      "metadata": {
        "id": "16dc0e41-80bd-4453-b421-dcf315741bf4"
      },
      "source": [
        "# Code generation with flow\n",
        "\n",
        "AlphaCodium presented an approach for code generation that uses control flow.\n",
        "\n",
        "Main idea: [construct an answer to a coding question iteratively](https://x.com/karpathy/status/1748043513156272416?s=20).\n",
        "\n",
        "[AlphaCodium](https://github.com/Codium-ai/AlphaCodium) iteratively tests and improves an answer against public and AI-generated tests for a particular question.\n",
        "\n",
        "We will implement some of these ideas from scratch using [LangGraph](https://python.langchain.com/docs/langgraph):\n",
        "\n",
        "1. We start with a set of documentation specified by a user\n",
        "2. We use a long-context LLM to ingest it and answer a question based on it\n",
        "3. We perform two unit tests: check that the imports run, then check that the full code block executes\n",
        "\n",
        "![Screenshot 2024-04-03 at 1.54.25 PM.png](attachment:a7336696-1044-4f8a-b4c6-79d3eb619cd7.png)"
      ]
    },
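    {
      "cell_type": "markdown",
      "id": "sketch-check-loop-md",
      "metadata": {
        "id": "sketch-check-loop-md"
      },
      "source": [
        "The generate, check, and retry loop described above can be sketched without any LLM calls. This is a minimal illustration under stated assumptions, not the notebook's implementation: `fake_generate` is a hypothetical stand-in for the model, and `Solution` is a hypothetical dataclass that mirrors the `prefix`, `imports`, and `code` fields of the `code` schema defined later.\n",
        "\n",
        "Each attempt runs the two unit tests (imports first, then imports plus code block) with `exec`. Every failure message is collected as feedback, much as the graph appends it to `messages`. Passing a fresh `{}` as the globals means the tested snippet cannot silently depend on variables already defined in the notebook. Running the cell should show the first attempt failing the import test and the second attempt passing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-check-loop",
      "metadata": {
        "id": "sketch-check-loop"
      },
      "outputs": [],
      "source": [
        "# Minimal, LLM-free sketch of the generate -> check -> retry loop.\n",
        "# `fake_generate` and `Solution` are hypothetical stand-ins, not part of LangGraph.\n",
        "from dataclasses import dataclass\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Solution:\n",
        "    prefix: str\n",
        "    imports: str\n",
        "    code: str\n",
        "\n",
        "\n",
        "def check_solution(solution):\n",
        "    # Test 1: run only the imports, in a fresh namespace\n",
        "    try:\n",
        "        exec(solution.imports, {})\n",
        "    except Exception as e:\n",
        "        return f'Your solution failed the import test: {e}'\n",
        "    # Test 2: run the imports plus the code block, in a fresh namespace\n",
        "    try:\n",
        "        exec(solution.imports + '\\n' + solution.code, {})\n",
        "    except Exception as e:\n",
        "        return f'Your solution failed the code execution test: {e}'\n",
        "    return None  # both tests passed\n",
        "\n",
        "\n",
        "def fake_generate(attempt):\n",
        "    # Hypothetical generator: the first answer has a bad import, the second is fixed\n",
        "    if attempt == 0:\n",
        "        return Solution('Use mathx', 'import mathx', 'result = mathx.sqrt(16)')\n",
        "    return Solution('Use math', 'import math', 'result = math.sqrt(16)')\n",
        "\n",
        "\n",
        "def solve(max_iterations=3):\n",
        "    feedback = []\n",
        "    for attempt in range(max_iterations):\n",
        "        solution = fake_generate(attempt)\n",
        "        error = check_solution(solution)\n",
        "        if error is None:\n",
        "            return solution, attempt + 1, feedback\n",
        "        # In the graph, this message is appended to `messages` as a user turn\n",
        "        feedback.append(error)\n",
        "    return solution, max_iterations, feedback\n",
        "\n",
        "\n",
        "final_solution, tries, feedback = solve()\n",
        "print(tries, feedback)"
      ]
    },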
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "e3900420",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "e3900420",
        "outputId": "bc507e14-9c66-4d79-b1a3-916c05d73a43"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting langchain_community\n",
            "  Downloading langchain_community-0.0.38-py3-none-any.whl (2.0 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting langchain-openai\n",
            "  Downloading langchain_openai-0.1.7-py3-none-any.whl (34 kB)\n",
            "Collecting langchain-anthropic\n",
            "  Downloading langchain_anthropic-0.1.12-py3-none-any.whl (16 kB)\n",
            "Collecting langchain\n",
            "  Downloading langchain-0.1.20-py3-none-any.whl (1.0 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting langgraph\n",
            "  Downloading langgraph-0.0.49-py3-none-any.whl (77 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.6/77.6 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting bs4\n",
            "  Downloading bs4-0.0.2-py2.py3-none-any.whl (1.2 kB)\n",
            "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (6.0.1)\n",
            "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.0.30)\n",
            "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (3.9.5)\n",
            "Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)\n",
            "  Downloading dataclasses_json-0.6.6-py3-none-any.whl (28 kB)\n",
            "Collecting langchain-core<0.2.0,>=0.1.52 (from langchain_community)\n",
            "  Downloading langchain_core-0.1.52-py3-none-any.whl (302 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.9/302.9 kB\u001b[0m \u001b[31m16.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting langsmith<0.2.0,>=0.1.0 (from langchain_community)\n",
            "  Downloading langsmith-0.1.59-py3-none-any.whl (121 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.2/121.2 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (1.25.2)\n",
            "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.31.0)\n",
            "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (8.3.0)\n",
            "Collecting openai<2.0.0,>=1.24.0 (from langchain-openai)\n",
            "  Downloading openai-1.30.1-py3-none-any.whl (320 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.6/320.6 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting tiktoken<1,>=0.7 (from langchain-openai)\n",
            "  Downloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m28.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting anthropic<1,>=0.23.0 (from langchain-anthropic)\n",
            "  Downloading anthropic-0.25.9-py3-none-any.whl (871 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m871.1/871.1 kB\u001b[0m \u001b[31m30.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: defusedxml<0.8.0,>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from langchain-anthropic) (0.7.1)\n",
            "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.3)\n",
            "Collecting langchain-text-splitters<0.1,>=0.0.1 (from langchain)\n",
            "  Downloading langchain_text_splitters-0.0.2-py3-none-any.whl (23 kB)\n",
            "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.7.1)\n",
            "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from bs4) (4.12.3)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.3.1)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (23.2.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.4.1)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (6.0.5)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.9.4)\n",
            "Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (3.7.1)\n",
            "Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.7.0)\n",
            "Collecting httpx<1,>=0.23.0 (from anthropic<1,>=0.23.0->langchain-anthropic)\n",
            "  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.3.1)\n",
            "Requirement already satisfied: tokenizers>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (0.19.1)\n",
            "Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (4.11.0)\n",
            "Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)\n",
            "  Downloading marshmallow-3.21.2-py3-none-any.whl (49 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.3/49.3 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)\n",
            "  Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)\n",
            "Collecting jsonpatch<2.0,>=1.33 (from langchain-core<0.2.0,>=0.1.52->langchain_community)\n",
            "  Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n",
            "Collecting packaging<24.0,>=23.2 (from langchain-core<0.2.0,>=0.1.52->langchain_community)\n",
            "  Downloading packaging-23.2-py3-none-any.whl (53 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.0/53.0 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.0->langchain_community)\n",
            "  Downloading orjson-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (142 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m142.5/142.5 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai<2.0.0,>=1.24.0->langchain-openai) (4.66.4)\n",
            "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
            "Requirement already satisfied: pydantic-core==2.18.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (2.18.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.7)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2024.2.2)\n",
            "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain_community) (3.0.3)\n",
            "Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken<1,>=0.7->langchain-openai) (2023.12.25)\n",
            "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->bs4) (2.5)\n",
            "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->anthropic<1,>=0.23.0->langchain-anthropic) (1.2.1)\n",
            "Collecting httpcore==1.* (from httpx<1,>=0.23.0->anthropic<1,>=0.23.0->langchain-anthropic)\n",
            "  Downloading httpcore-1.0.5-py3-none-any.whl (77 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->anthropic<1,>=0.23.0->langchain-anthropic)\n",
            "  Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting jsonpointer>=1.9 (from jsonpatch<2.0,>=1.33->langchain-core<0.2.0,>=0.1.52->langchain_community)\n",
            "  Downloading jsonpointer-2.4-py2.py3-none-any.whl (7.8 kB)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (0.20.3)\n",
            "Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain_community)\n",
            "  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (3.14.0)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (2023.6.0)\n",
            "Installing collected packages: packaging, orjson, mypy-extensions, jsonpointer, h11, typing-inspect, tiktoken, marshmallow, jsonpatch, httpcore, bs4, langsmith, httpx, dataclasses-json, openai, langchain-core, anthropic, langgraph, langchain-text-splitters, langchain-openai, langchain_community, langchain-anthropic, langchain\n",
            "  Attempting uninstall: packaging\n",
            "    Found existing installation: packaging 24.0\n",
            "    Uninstalling packaging-24.0:\n",
            "      Successfully uninstalled packaging-24.0\n",
            "Successfully installed anthropic-0.25.9 bs4-0.0.2 dataclasses-json-0.6.6 h11-0.14.0 httpcore-1.0.5 httpx-0.27.0 jsonpatch-1.33 jsonpointer-2.4 langchain-0.1.20 langchain-anthropic-0.1.12 langchain-core-0.1.52 langchain-openai-0.1.7 langchain-text-splitters-0.0.2 langchain_community-0.0.38 langgraph-0.0.49 langsmith-0.1.59 marshmallow-3.21.2 mypy-extensions-1.0.0 openai-1.30.1 orjson-3.10.3 packaging-23.2 tiktoken-0.7.0 typing-inspect-0.9.0\n"
          ]
        }
      ],
      "source": [
        "! pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "os.environ['OPENAI_API_KEY'] = ''\n",
        "os.environ['ANTHROPIC_API_KEY'] = ''\n",
        "\n",
        "# Optional, add tracing in LangSmith\n",
        "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
        "os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'\n",
        "os.environ[\"LANGCHAIN_PROJECT\"] = \"langgraph_code_assistant\"\n",
        "os.environ['LANGCHAIN_API_KEY'] = ''"
      ],
      "metadata": {
        "id": "thz4d-ctMrIC"
      },
      "id": "thz4d-ctMrIC",
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "id": "38330223-d8c8-4156-82b6-93e63343bc01",
      "metadata": {
        "id": "38330223-d8c8-4156-82b6-93e63343bc01"
      },
      "source": [
        "## Docs\n",
        "\n",
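        "Load the [Meta-Llama-3 drug discovery folder](https://github.com/kevinkawchak/Medical-Quantum-Machine-Learning/tree/main/Code/Drug%20Discovery/Meta-Llama-3) of the Medical-Quantum-Machine-Learning repository as the example documentation. (The original LangGraph example loaded the [LangChain Expression Language](https://python.langchain.com/docs/expression_language/) (LCEL) docs instead.)"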
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "c2eb35d1-4990-47dc-a5c4-208bae588a82",
      "metadata": {
        "id": "c2eb35d1-4990-47dc-a5c4-208bae588a82"
      },
      "outputs": [],
      "source": [
        "from bs4 import BeautifulSoup as Soup\n",
        "from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader\n",
        "\n",
        "# Meta-Llama-3 drug discovery notebooks on GitHub (replaces the LCEL docs)\n",
        "url = \"https://github.com/kevinkawchak/Medical-Quantum-Machine-Learning/tree/main/Code/Drug%20Discovery/Meta-Llama-3\"\n",
        "loader = RecursiveUrlLoader(\n",
        "    url=url, max_depth=20, extractor=lambda x: Soup(x, \"html.parser\").text\n",
        ")\n",
        "docs = loader.load()\n",
        "\n",
        "# Sort the list based on the URLs and get the text\n",
        "d_sorted = sorted(docs, key=lambda x: x.metadata[\"source\"])\n",
        "d_reversed = list(reversed(d_sorted))\n",
        "concatenated_content = \"\\n\\n\\n --- \\n\\n\\n\".join(\n",
        "    [doc.page_content for doc in d_reversed]\n",
        ")"
      ]
    },
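    {
      "cell_type": "markdown",
      "id": "sketch-concat-md",
      "metadata": {
        "id": "sketch-concat-md"
      },
      "source": [
        "To see what the concatenation step produces without crawling GitHub, the sketch below repeats the same sort, reverse, and join on three hypothetical stand-in documents. These are `SimpleNamespace` objects exposing the `page_content` and `metadata['source']` attributes used in the loader cell, and the example.com URLs are placeholders. The pages end up in descending URL order, separated by the same `---` divider."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-concat",
      "metadata": {
        "id": "sketch-concat"
      },
      "outputs": [],
      "source": [
        "# Offline sketch of the sort/reverse/join step above.\n",
        "# The documents and example.com URLs are hypothetical stand-ins for loader output.\n",
        "from types import SimpleNamespace\n",
        "\n",
        "fake_docs = [\n",
        "    SimpleNamespace(page_content='page B', metadata={'source': 'https://example.com/b'}),\n",
        "    SimpleNamespace(page_content='page A', metadata={'source': 'https://example.com/a'}),\n",
        "    SimpleNamespace(page_content='page C', metadata={'source': 'https://example.com/c'}),\n",
        "]\n",
        "\n",
        "separator = '\\n\\n\\n --- \\n\\n\\n'  # same divider as concatenated_content\n",
        "# Sort ascending by source URL, then reverse, as in the loader cell\n",
        "ordered = list(reversed(sorted(fake_docs, key=lambda d: d.metadata['source'])))\n",
        "joined = separator.join(d.page_content for d in ordered)\n",
        "print([d.metadata['source'] for d in ordered])"
      ]
    },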
    {
      "cell_type": "markdown",
      "id": "662d4ff4-1709-412f-bfed-5eb2b8d3d3dc",
      "metadata": {
        "id": "662d4ff4-1709-412f-bfed-5eb2b8d3d3dc"
      },
      "source": [
        "## LLMs\n",
        "\n",
        "### Code solution\n",
        "\n",
        "Try OpenAI (GPT-4o, used below) and [Claude 3](https://docs.anthropic.com/claude/docs/models-overview) (commented out in the next cell) with function calling."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "3ba3df70-f6b4-4ea5-a210-e10944960bc6",
      "metadata": {
        "id": "3ba3df70-f6b4-4ea5-a210-e10944960bc6"
      },
      "outputs": [],
      "source": [
        "from langchain_openai import ChatOpenAI\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from langchain_core.pydantic_v1 import BaseModel, Field\n",
        "\n",
        "### OpenAI\n",
        "\n",
        "# Code generation prompt\n",
        "code_gen_prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\n",
        "            \"system\",\n",
        "            \"\"\"You are a coding assistant with expertise in Python. \\n\n",
        "    Here is a full set of documentation:  \\n ------- \\n  {context} \\n ------- \\n Answer the user\n",
        "    question based on the above provided documentation. Ensure any code you provide can be executed \\n\n",
        "    with all required imports and variables defined. Structure your answer with a description of the code solution. \\n\n",
        "    Then list the imports. And finally list the functioning code block. Here is the user question:\"\"\",\n",
        "        ),\n",
        "        (\"placeholder\", \"{messages}\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "\n",
        "# Data model\n",
        "class code(BaseModel):\n",
        "    \"\"\"Code output\"\"\"\n",
        "\n",
        "    prefix: str = Field(description=\"Description of the problem and approach\")\n",
        "    imports: str = Field(description=\"Code block import statements\")\n",
        "    code: str = Field(description=\"Code block not including import statements\")\n",
        "    description = \"Schema for code solutions to questions about Python.\"\n",
        "\n",
        "\n",
        "expt_llm = \"gpt-4o-2024-05-13\"\n",
        "llm = ChatOpenAI(temperature=0, model=expt_llm)\n",
        "code_gen_chain = code_gen_prompt | llm.with_structured_output(code)\n",
        "question = \"How do I fine-tune Llama 3 with the Mol-Instructions dataset?\"\n",
        "# solution = code_gen_chain.invoke({\"context\": concatenated_content, \"messages\": [(\"user\", question)]})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "cd30b67d-96db-4e51-a540-ae23fcc1f878",
      "metadata": {
        "id": "cd30b67d-96db-4e51-a540-ae23fcc1f878"
      },
      "outputs": [],
      "source": [
        "# from langchain_anthropic import ChatAnthropic\n",
        "# from langchain_core.prompts import ChatPromptTemplate\n",
        "# from langchain_core.pydantic_v1 import BaseModel, Field\n",
        "\n",
        "# ### Anthropic\n",
        "\n",
        "# # Prompt to enforce tool use\n",
        "# code_gen_prompt_claude = ChatPromptTemplate.from_messages(\n",
        "#     [\n",
        "#         (\n",
        "#             \"system\",\n",
        "#             \"\"\"<instructions> You are a coding assistant with expertise in Python. \\n\n",
        "#     Here is the documentation:  \\n ------- \\n  {context} \\n ------- \\n Answer the user  question based on the \\n\n",
        "#     above provided documentation. Ensure any code you provide can be executed with all required imports and variables \\n\n",
        "#     defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \\n\n",
        "#     Invoke the code tool to structure the output correctly. </instructions> \\n Here is the user question:\"\"\",\n",
        "#         ),\n",
        "#         (\"placeholder\", \"{messages}\"),\n",
        "#     ]\n",
        "# )\n",
        "\n",
        "\n",
        "# # Data model\n",
        "# class code(BaseModel):\n",
        "#     \"\"\"Code output\"\"\"\n",
        "\n",
        "#     prefix: str = Field(description=\"Description of the problem and approach\")\n",
        "#     imports: str = Field(description=\"Code block import statements\")\n",
        "#     code: str = Field(description=\"Code block not including import statements\")\n",
        "#     description = \"Schema for code solutions to questions about Python.\"\n",
        "\n",
        "\n",
        "# # LLM\n",
        "# # expt_llm = \"claude-3-haiku-20240307\"\n",
        "# expt_llm = \"claude-3-opus-20240229\"\n",
        "# llm = ChatAnthropic(\n",
        "#     model=expt_llm,\n",
        "#     default_headers={\"anthropic-beta\": \"tools-2024-04-04\"},\n",
        "# )\n",
        "\n",
        "# structured_llm_claude = llm.with_structured_output(code, include_raw=True)\n",
        "\n",
        "\n",
        "# # Optional: Check for errors in case tool use is flaky\n",
        "# def check_claude_output(tool_output):\n",
        "#     \"\"\"Check for parse error or failure to call the tool\"\"\"\n",
        "\n",
        "#     # Error with parsing\n",
        "#     if tool_output[\"parsing_error\"]:\n",
        "#         # Report back output and parsing errors\n",
        "#         print(\"Parsing error!\")\n",
        "#         raw_output = str(tool_output[\"raw\"].content)\n",
        "#         error = tool_output[\"parsing_error\"]\n",
        "#         raise ValueError(\n",
        "#             f\"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \\n Parse error: {error}\"\n",
        "#         )\n",
        "\n",
        "#     # Tool was not invoked\n",
        "#     elif not tool_output[\"parsed\"]:\n",
        "#         print(\"Failed to invoke tool!\")\n",
        "#         raise ValueError(\n",
        "#             f\"You did not use the provided tool! Be sure to invoke the tool to structure the output.\"\n",
        "#         )\n",
        "#     return tool_output\n",
        "\n",
        "\n",
        "# # Chain with output check\n",
        "# code_chain_claude_raw = (\n",
        "#     code_gen_prompt_claude | structured_llm_claude | check_claude_output\n",
        "# )\n",
        "\n",
        "\n",
        "# def insert_errors(inputs):\n",
        "#     \"\"\"Insert errors for tool parsing in the messages\"\"\"\n",
        "\n",
        "#     # Get errors\n",
        "#     error = inputs[\"error\"]\n",
        "#     messages = inputs[\"messages\"]\n",
        "#     messages += [\n",
        "#         (\n",
        "#             \"assistant\",\n",
        "#             f\"Retry. You are required to fix the parsing errors: {error} \\n\\n You must invoke the provided tool.\",\n",
        "#         )\n",
        "#     ]\n",
        "#     return {\n",
        "#         \"messages\": messages,\n",
        "#         \"context\": inputs[\"context\"],\n",
        "#     }\n",
        "\n",
        "\n",
        "# # This will be run as a fallback chain\n",
        "# fallback_chain = insert_errors | code_chain_claude_raw\n",
        "# N = 3  # Max re-tries\n",
        "# code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(\n",
        "#     fallbacks=[fallback_chain] * N, exception_key=\"error\"\n",
        "# )\n",
        "\n",
        "\n",
        "# def parse_output(solution):\n",
        "#     \"\"\"When we add 'include_raw=True' to structured output,\n",
        "#     it will return a dict w 'raw', 'parsed', 'parsing_error'.\"\"\"\n",
        "\n",
        "#     return solution[\"parsed\"]\n",
        "\n",
        "\n",
        "# # With re-try to correct for failure to invoke tool\n",
        "# # TODO: Annoying errors w/ \"user\" vs \"assistant\"\n",
        "# # Roles must alternate between \"user\" and \"assistant\", but found multiple \"user\" roles in a row\n",
        "# code_gen_chain = code_gen_chain_re_try | parse_output\n",
        "\n",
        "# # No re-try\n",
        "# code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "9f14750f-dddc-485b-ba29-5392cdf4ba43",
      "metadata": {
        "scrolled": true,
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "9f14750f-dddc-485b-ba29-5392cdf4ba43",
        "outputId": "22ecb904-e065-4239-eea1-586c27285e03"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "code(prefix='Fine-tuning Llama 3 with the Mol-Instructions dataset involves several steps. Below is a Python code example that demonstrates how to fine-tune the Llama 3 model using the Hugging Face Transformers library and the Mol-Instructions dataset. The code includes necessary imports, dataset loading, model loading, training setup, and the training loop.', imports='import torch\\nfrom transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset', code=\"# Load the Mol-Instructions dataset\\ndataset = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Load the Llama 3 model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = LlamaForCausalLM.from_pretrained(model_name)\\ntokenizer = LlamaTokenizer.from_pretrained(model_name)\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n    return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_datasets = dataset.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n    output_dir='./results',\\n    evaluation_strategy='epoch',\\n    learning_rate=2e-5,\\n    per_device_train_batch_size=4,\\n    per_device_eval_batch_size=4,\\n    num_train_epochs=3,\\n    weight_decay=0.01,\\n    save_total_limit=2,\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n    model=model,\\n    args=training_args,\\n    train_dataset=tokenized_datasets['train'],\\n    eval_dataset=tokenized_datasets['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\", description='Schema for code solutions to questions about Python.')"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "# Test\n",
        "question = \"How do I fine-tune Llama 3 with the Mol-Instructions dataset?\"\n",
        "solution = code_gen_chain.invoke(\n",
        "    {\"context\": concatenated_content, \"messages\": [(\"user\", question)]}\n",
        ")\n",
        "solution"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "131f2055-2f64-4d19-a3d1-2d3cb8b42894",
      "metadata": {
        "id": "131f2055-2f64-4d19-a3d1-2d3cb8b42894"
      },
      "source": [
        "## State\n",
        "\n",
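        "Our state is a dict whose keys (`error`, `messages`, `generation`, `iterations`) track everything relevant to code generation: the test-failure flag, the conversation (including the user question and error feedback), the current code solution, and the number of attempts."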
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "id": "c185f1a2-e943-4bed-b833-4243c9c64092",
      "metadata": {
        "id": "c185f1a2-e943-4bed-b833-4243c9c64092"
      },
      "outputs": [],
      "source": [
        "from typing import Dict, TypedDict, List\n",
        "\n",
        "\n",
        "class GraphState(TypedDict):\n",
        "    \"\"\"\n",
        "    Represents the state of our graph.\n",
        "\n",
        "    Attributes:\n",
        "        error : Binary flag for control flow to indicate whether test error was tripped\n",
        "        messages : With user question, error messages, reasoning\n",
        "        generation : Code solution\n",
        "        iterations : Number of tries\n",
        "    \"\"\"\n",
        "\n",
        "    error: str\n",
        "    messages: List\n",
        "    generation: str\n",
        "    iterations: int"
      ]
    },
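    {
      "cell_type": "markdown",
      "id": "sketch-routing-md",
      "metadata": {
        "id": "sketch-routing-md"
      },
      "source": [
        "The `GraphState` keys are what drive the retry decision. The routing step is not part of this section of the notebook, so the sketch below uses a hypothetical `route_after_check` helper. It assumes that a passing check sets `error` to `\"no\"` (`code_check` sets `\"yes\"` on failure) and that generation stops once `iterations` reaches a cap such as `max_iterations`. Plain dicts with the same keys stand in for graph states."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-routing",
      "metadata": {
        "id": "sketch-routing"
      },
      "outputs": [],
      "source": [
        "# Hypothetical routing helper (not defined in this notebook section).\n",
        "# Assumes a passing check sets state['error'] to 'no'.\n",
        "def route_after_check(state, max_iterations=3):\n",
        "    # Stop on success or once the retry budget is spent; otherwise regenerate\n",
        "    if state['error'] == 'no' or state['iterations'] >= max_iterations:\n",
        "        return 'end'\n",
        "    return 'generate'\n",
        "\n",
        "\n",
        "examples = [\n",
        "    {'error': 'no', 'messages': [], 'generation': None, 'iterations': 1},\n",
        "    {'error': 'yes', 'messages': [], 'generation': None, 'iterations': 1},\n",
        "    {'error': 'yes', 'messages': [], 'generation': None, 'iterations': 3},\n",
        "]\n",
        "print([route_after_check(s) for s in examples])"
      ]
    },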
    {
      "cell_type": "markdown",
      "id": "64454465-26a3-40de-ad85-bcf59a2c3086",
      "metadata": {
        "id": "64454465-26a3-40de-ad85-bcf59a2c3086"
      },
      "source": [
        "## Graph\n",
        "\n",
        "Our graph lays out the logical flow shown in the figure above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "b70e8301-63ae-4f7e-ad8f-c9a052fe3566",
      "metadata": {
        "id": "b70e8301-63ae-4f7e-ad8f-c9a052fe3566"
      },
      "outputs": [],
      "source": [
        "from operator import itemgetter\n",
        "from langchain_core.pydantic_v1 import BaseModel, Field\n",
        "from langchain_core.runnables import RunnablePassthrough\n",
        "from langchain_core.prompts import PromptTemplate\n",
        "\n",
        "### Parameter\n",
        "\n",
        "# Max tries\n",
        "max_iterations = 3\n",
        "# Reflect\n",
        "# flag = 'reflect'\n",
        "flag = \"do not reflect\"\n",
        "\n",
        "### Nodes\n",
        "\n",
        "\n",
        "def generate(state: GraphState):\n",
        "    \"\"\"\n",
        "    Generate a code solution\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): New key added to state, generation\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---GENERATING CODE SOLUTION---\")\n",
        "\n",
        "    # State\n",
        "    messages = state[\"messages\"]\n",
        "    iterations = state[\"iterations\"]\n",
        "    error = state[\"error\"]\n",
        "\n",
        "    # We have been routed back to generation with an error\n",
        "    if error == \"yes\":\n",
        "        messages += [\n",
        "            (\n",
        "                \"user\",\n",
        "                \"Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:\",\n",
        "            )\n",
        "        ]\n",
        "\n",
        "    # Solution\n",
        "    code_solution = code_gen_chain.invoke(\n",
        "        {\"context\": concatenated_content, \"messages\": messages}\n",
        "    )\n",
        "    messages += [\n",
        "        (\n",
        "            \"assistant\",\n",
        "            f\"{code_solution.prefix} \\n Imports: {code_solution.imports} \\n Code: {code_solution.code}\",\n",
        "        )\n",
        "    ]\n",
        "\n",
        "    # Increment\n",
        "    iterations = iterations + 1\n",
        "    return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n",
        "\n",
        "\n",
        "def code_check(state: GraphState):\n",
        "    \"\"\"\n",
        "    Check code\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): New key added to state, error\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---CHECKING CODE---\")\n",
        "\n",
        "    # State\n",
        "    messages = state[\"messages\"]\n",
        "    code_solution = state[\"generation\"]\n",
        "    iterations = state[\"iterations\"]\n",
        "\n",
        "    # Get solution components\n",
        "    prefix = code_solution.prefix\n",
        "    imports = code_solution.imports\n",
        "    code = code_solution.code\n",
        "\n",
        "    # Check imports\n",
        "    try:\n",
        "        exec(imports)\n",
        "    except Exception as e:\n",
        "        print(\"---CODE IMPORT CHECK: FAILED---\")\n",
        "        error_message = [(\"user\", f\"Your solution failed the import test: {e}\")]\n",
        "        messages += error_message\n",
        "        return {\n",
        "            \"generation\": code_solution,\n",
        "            \"messages\": messages,\n",
        "            \"iterations\": iterations,\n",
        "            \"error\": \"yes\",\n",
        "        }\n",
        "\n",
        "    # Check execution\n",
        "    try:\n",
        "        exec(imports + \"\\n\" + code)\n",
        "    except Exception as e:\n",
        "        print(\"---CODE BLOCK CHECK: FAILED---\")\n",
        "        error_message = [(\"user\", f\"Your solution failed the code execution test: {e}\")]\n",
        "        messages += error_message\n",
        "        return {\n",
        "            \"generation\": code_solution,\n",
        "            \"messages\": messages,\n",
        "            \"iterations\": iterations,\n",
        "            \"error\": \"yes\",\n",
        "        }\n",
        "\n",
        "    # No errors\n",
        "    print(\"---NO CODE TEST FAILURES---\")\n",
        "    return {\n",
        "        \"generation\": code_solution,\n",
        "        \"messages\": messages,\n",
        "        \"iterations\": iterations,\n",
        "        \"error\": \"no\",\n",
        "    }\n",
        "\n",
        "\n",
        "def reflect(state: GraphState):\n",
        "    \"\"\"\n",
        "    Reflect on errors\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): Updated state with the reflection appended to messages\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---REFLECTING ON ERRORS---\")\n",
        "\n",
        "    # State\n",
        "    messages = state[\"messages\"]\n",
        "    iterations = state[\"iterations\"]\n",
        "    code_solution = state[\"generation\"]\n",
        "\n",
        "    # Prompt reflection\n",
        "    reflection_message = [\n",
        "        (\n",
        "            \"user\",\n",
        "            \"You tried to solve this problem and failed a unit test. Reflect on this failure \"\n",
        "            \"given the provided documentation. Write a few key suggestions based on the \"\n",
        "            \"documentation to avoid making this mistake again.\",\n",
        "        )\n",
        "    ]\n",
        "\n",
        "    # Add reflection (the reflection prompt is included in the messages sent to the LLM)\n",
        "    reflections = code_gen_chain.invoke(\n",
        "        {\"context\": concatenated_content, \"messages\": messages + reflection_message}\n",
        "    )\n",
        "    messages += [(\"assistant\", f\"Here are reflections on the error: {reflections}\")]\n",
        "    return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n",
        "\n",
        "\n",
        "### Edges\n",
        "\n",
        "\n",
        "def decide_to_finish(state: GraphState):\n",
        "    \"\"\"\n",
        "    Determines whether to finish.\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        str: Next node to call\n",
        "    \"\"\"\n",
        "    error = state[\"error\"]\n",
        "    iterations = state[\"iterations\"]\n",
        "\n",
        "    if error == \"no\" or iterations == max_iterations:\n",
        "        print(\"---DECISION: FINISH---\")\n",
        "        return \"end\"\n",
        "    else:\n",
        "        print(\"---DECISION: RE-TRY SOLUTION---\")\n",
        "        if flag == \"reflect\":\n",
        "            return \"reflect\"\n",
        "        else:\n",
        "            return \"generate\""
      ]
    },
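    {
      "cell_type": "markdown",
      "id": "sketch-code-checks-md",
      "metadata": {},
      "source": [
        "The `code_check` node runs two checks in order: the imports on their own, then the imports followed by the code block. The cell below is a minimal, self-contained sketch of that two-stage logic with no LLM involved, so the routing can be tried on hand-written snippets. The helper `run_checks` is hypothetical, and unlike `code_check` it runs each snippet in a fresh namespace dictionary, so generated code cannot overwrite notebook globals such as `messages` or `app`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-code-checks",
      "metadata": {},
      "outputs": [],
      "source": [
        "def run_checks(imports: str, code: str):\n",
        "    \"\"\"Hypothetical helper mirroring code_check; returns (error_flag, feedback).\"\"\"\n",
        "    # Stage 1: the import statements alone, in an empty namespace\n",
        "    try:\n",
        "        exec(imports, {})\n",
        "    except Exception as e:\n",
        "        return \"yes\", f\"Your solution failed the import test: {e}\"\n",
        "    # Stage 2: imports followed by the code block, again in a fresh namespace\n",
        "    try:\n",
        "        exec(imports + \"\\n\" + code, {})\n",
        "    except Exception as e:\n",
        "        return \"yes\", f\"Your solution failed the code execution test: {e}\"\n",
        "    return \"no\", None\n",
        "\n",
        "\n",
        "# Usage on hand-written snippets\n",
        "print(run_checks(\"import math\", \"assert math.sqrt(16) == 4\"))\n",
        "print(run_checks(\"import not_a_real_module_xyz\", \"pass\"))\n",
        "print(run_checks(\"import math\", \"math.log(-1)\"))"
      ]
    },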
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "f66b4e00-4731-42c8-bc38-72dd0ff7c92c",
      "metadata": {
        "id": "f66b4e00-4731-42c8-bc38-72dd0ff7c92c"
      },
      "outputs": [],
      "source": [
        "from langgraph.graph import END, StateGraph\n",
        "\n",
        "workflow = StateGraph(GraphState)\n",
        "\n",
        "# Define the nodes\n",
        "workflow.add_node(\"generate\", generate)  # generate a code solution\n",
        "workflow.add_node(\"check_code\", code_check)  # check code\n",
        "workflow.add_node(\"reflect\", reflect)  # reflect\n",
        "\n",
        "# Build graph\n",
        "workflow.set_entry_point(\"generate\")\n",
        "workflow.add_edge(\"generate\", \"check_code\")\n",
        "workflow.add_conditional_edges(\n",
        "    \"check_code\",\n",
        "    decide_to_finish,\n",
        "    {\n",
        "        \"end\": END,\n",
        "        \"reflect\": \"reflect\",\n",
        "        \"generate\": \"generate\",\n",
        "    },\n",
        ")\n",
        "workflow.add_edge(\"reflect\", \"generate\")\n",
        "app = workflow.compile()"
      ]
    },
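    {
      "cell_type": "markdown",
      "id": "sketch-control-loop-md",
      "metadata": {},
      "source": [
        "The compiled graph repeats `generate` → `check_code` → `decide_to_finish` until the code passes its checks or `max_iterations` generations have been made, optionally passing through `reflect` before each retry. As a rough, framework-free illustration of that routing (not of how LangGraph runs the graph internally), the sketch below drives the same loop with plain Python functions. `run_loop`, `fake_generate`, and `fake_check` are hypothetical stand-ins for the LLM-backed nodes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-control-loop",
      "metadata": {},
      "outputs": [],
      "source": [
        "def run_loop(generate_fn, check_fn, max_iterations=3, reflect_fn=None):\n",
        "    \"\"\"Hypothetical, framework-free version of the generate/check/reflect loop.\"\"\"\n",
        "    state = {\"messages\": [], \"iterations\": 0, \"error\": \"\"}\n",
        "    path = []  # names of the nodes visited, in order\n",
        "    node = \"generate\"\n",
        "    while True:\n",
        "        path.append(node)\n",
        "        if node == \"generate\":\n",
        "            state[\"generation\"] = generate_fn(state)\n",
        "            state[\"iterations\"] += 1\n",
        "            node = \"check_code\"\n",
        "        elif node == \"check_code\":\n",
        "            state[\"error\"] = check_fn(state[\"generation\"])\n",
        "            # Stop on success or once max_iterations generations have been made\n",
        "            if state[\"error\"] == \"no\" or state[\"iterations\"] >= max_iterations:\n",
        "                break\n",
        "            node = \"reflect\" if reflect_fn is not None else \"generate\"\n",
        "        else:  # node == \"reflect\"\n",
        "            reflect_fn(state)\n",
        "            node = \"generate\"\n",
        "    return state, path\n",
        "\n",
        "\n",
        "def fake_generate(state):\n",
        "    # Stand-in for the LLM: each call yields a new candidate label\n",
        "    return f\"attempt-{state['iterations'] + 1}\"\n",
        "\n",
        "\n",
        "def fake_check(generation):\n",
        "    # Stand-in for code_check: only the second attempt passes\n",
        "    return \"no\" if generation == \"attempt-2\" else \"yes\"\n",
        "\n",
        "\n",
        "final_state, path = run_loop(fake_generate, fake_check)\n",
        "print(final_state[\"iterations\"], final_state[\"error\"])\n",
        "print(path)"
      ]
    },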
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "9bcaafe4-ddcf-4fab-8620-2d9b6c508f98",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "9bcaafe4-ddcf-4fab-8620-2d9b6c508f98",
        "outputId": "4020e91e-b1f8-40b1-ebbd-b06385321068"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "---GENERATING CODE SOLUTION---\n",
            "---CHECKING CODE---\n",
            "---CODE IMPORT CHECK: FAILED---\n",
            "---DECISION: RE-TRY SOLUTION---\n",
            "---GENERATING CODE SOLUTION---\n",
            "---CHECKING CODE---\n",
            "---CODE IMPORT CHECK: FAILED---\n",
            "---DECISION: RE-TRY SOLUTION---\n",
            "---GENERATING CODE SOLUTION---\n",
            "---CHECKING CODE---\n",
            "---CODE IMPORT CHECK: FAILED---\n",
            "---DECISION: FINISH---\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'error': 'yes',\n",
              " 'messages': [('user',\n",
              "   'How can I fine-tune Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset?'),\n",
              "  ('assistant',\n",
              "   \"Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset involves several steps. The following code demonstrates how to set up and execute the fine-tuning process using the Hugging Face Transformers library and PyTorch. The process includes loading the pre-trained model, preparing the dataset, and running the training loop. \\n Imports: import torch\\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset \\n Code: # Load the pre-trained model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\\ntokenizer = AutoTokenizer.from_pretrained(model_name)\\n\\n# Load the Mol-Instructions dataset\\ndataset = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n    return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_datasets = dataset.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n    output_dir='./results',\\n    evaluation_strategy='epoch',\\n    learning_rate=2e-5,\\n    per_device_train_batch_size=4,\\n    per_device_eval_batch_size=4,\\n    num_train_epochs=3,\\n    weight_decay=0.01,\\n    logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n    model=model,\\n    args=training_args,\\n    train_dataset=tokenized_datasets['train'],\\n    eval_dataset=tokenized_datasets['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\"),\n",
              "  ('user', \"Your solution failed the import test: No module named 'datasets'\"),\n",
              "  ('user',\n",
              "   'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n",
              "  ('assistant',\n",
              "   \"Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset involves several steps. The following code demonstrates how to set up and execute the fine-tuning process using the Hugging Face Transformers library and PyTorch. The process includes loading the pre-trained model, preparing the dataset, and running the training loop. \\n Imports: import torch\\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset \\n Code: # Load the pre-trained model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\\ntokenizer = AutoTokenizer.from_pretrained(model_name)\\n\\n# Load the Mol-Instructions dataset\\ndataset = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n    return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_datasets = dataset.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n    output_dir='./results',\\n    evaluation_strategy='epoch',\\n    learning_rate=2e-5,\\n    per_device_train_batch_size=4,\\n    per_device_eval_batch_size=4,\\n    num_train_epochs=3,\\n    weight_decay=0.01,\\n    logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n    model=model,\\n    args=training_args,\\n    train_dataset=tokenized_datasets['train'],\\n    eval_dataset=tokenized_datasets['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\"),\n",
              "  ('user', \"Your solution failed the import test: No module named 'datasets'\"),\n",
              "  ('user',\n",
              "   'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n",
              "  ('assistant',\n",
              "   \"Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset involves several steps. The following code demonstrates how to set up and execute the fine-tuning process using the Hugging Face Transformers library and PyTorch. The process includes loading the pre-trained model, preparing the dataset, and running the training loop. \\n Imports: import torch\\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset \\n Code: # Load the pre-trained model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\\ntokenizer = AutoTokenizer.from_pretrained(model_name)\\n\\n# Load the Mol-Instructions dataset\\ndataset = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n    return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_datasets = dataset.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n    output_dir='./results',\\n    evaluation_strategy='epoch',\\n    learning_rate=2e-5,\\n    per_device_train_batch_size=4,\\n    per_device_eval_batch_size=4,\\n    num_train_epochs=3,\\n    weight_decay=0.01,\\n    logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n    model=model,\\n    args=training_args,\\n    train_dataset=tokenized_datasets['train'],\\n    eval_dataset=tokenized_datasets['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\"),\n",
              "  ('user',\n",
              "   \"Your solution failed the import test: No module named 'datasets'\")],\n",
              " 'generation': code(prefix='Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset involves several steps. The following code demonstrates how to set up and execute the fine-tuning process using the Hugging Face Transformers library and PyTorch. The process includes loading the pre-trained model, preparing the dataset, and running the training loop.', imports='import torch\\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset', code=\"# Load the pre-trained model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\\ntokenizer = AutoTokenizer.from_pretrained(model_name)\\n\\n# Load the Mol-Instructions dataset\\ndataset = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n    return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_datasets = dataset.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n    output_dir='./results',\\n    evaluation_strategy='epoch',\\n    learning_rate=2e-5,\\n    per_device_train_batch_size=4,\\n    per_device_eval_batch_size=4,\\n    num_train_epochs=3,\\n    weight_decay=0.01,\\n    logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n    model=model,\\n    args=training_args,\\n    train_dataset=tokenized_datasets['train'],\\n    eval_dataset=tokenized_datasets['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\", description='Schema for code solutions to questions about Python.'),\n",
              " 'iterations': 3}"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ],
      "source": [
        "question = \"How can I fine-tune Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset?\"\n",
        "app.invoke({\"messages\": [(\"user\", question)], \"iterations\": 0, \"error\": \"\"})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "744f48a5-9ad3-4342-899f-7dd4266a9a15",
      "metadata": {
        "id": "744f48a5-9ad3-4342-899f-7dd4266a9a15"
      },
      "source": [
        "## Eval"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "89852874-b538-4c8d-a4c3-1d68302db492",
      "metadata": {
        "id": "89852874-b538-4c8d-a4c3-1d68302db492"
      },
      "source": [
        "[Here](https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d) is a public dataset of LCEL questions.\n",
        "\n",
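        "I saved this as `test-LCEL-code-gen`. The evaluation cells below look the dataset up by that name, so rename your cloned copy (or change `dataset_name`) to match; otherwise `evaluate` raises `LangSmithNotFoundError`, as in the recorded run below.\n",
        "\n",
        "You can also find the CSV [here](https://github.com/langchain-ai/lcel-teacher/blob/main/eval/eval.csv)."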
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "678e8954-56b5-4cc6-be26-f7f2a060b242",
      "metadata": {
        "id": "678e8954-56b5-4cc6-be26-f7f2a060b242"
      },
      "outputs": [],
      "source": [
        "import langsmith\n",
        "\n",
        "client = langsmith.Client()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "ef7cf662-7a6f-4dee-965c-6309d4045feb",
      "metadata": {
        "id": "ef7cf662-7a6f-4dee-965c-6309d4045feb"
      },
      "outputs": [],
      "source": [
        "# Clone the dataset to your tenant to use it\n",
        "public_dataset = (\n",
        "    \"https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d\"\n",
        ")\n",
        "client.clone_public_dataset(public_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9d171396-022b-47ec-a741-c782aff9fdae",
      "metadata": {
        "id": "9d171396-022b-47ec-a741-c782aff9fdae"
      },
      "source": [
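        "Custom evaluators: `check_import` scores whether the generated imports run, and `check_execution` scores whether the imports plus the code block run without raising an exception."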
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "455a34ea-52cb-4ae5-9f4a-7e4a08cd0c09",
      "metadata": {
        "id": "455a34ea-52cb-4ae5-9f4a-7e4a08cd0c09"
      },
      "outputs": [],
      "source": [
        "from langsmith.schemas import Example, Run\n",
        "\n",
        "\n",
        "def check_import(run: Run, example: Example) -> dict:\n",
        "    imports = run.outputs.get(\"imports\")\n",
        "    try:\n",
        "        exec(imports)\n",
        "        return {\"key\": \"import_check\", \"score\": 1}\n",
        "    except Exception:\n",
        "        return {\"key\": \"import_check\", \"score\": 0}\n",
        "\n",
        "\n",
        "def check_execution(run: Run, example: Example) -> dict:\n",
        "    imports = run.outputs.get(\"imports\")\n",
        "    code = run.outputs.get(\"code\")\n",
        "    try:\n",
        "        exec(imports + \"\\n\" + code)\n",
        "        return {\"key\": \"code_execution_check\", \"score\": 1}\n",
        "    except Exception:\n",
        "        return {\"key\": \"code_execution_check\", \"score\": 0}"
      ]
    },
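    {
      "cell_type": "markdown",
      "id": "sketch-subprocess-check-md",
      "metadata": {},
      "source": [
        "Both evaluators call `exec` inside the notebook process, so generated code can overwrite notebook globals, and a snippet that never terminates would stall the whole evaluation. One possible hardening, offered here as an assumption rather than part of the original method, is to run each snippet in a separate Python interpreter with a timeout using the standard-library `subprocess` module. The name `check_execution_subprocess` and the 30-second default timeout are illustrative choices; to pass it to `evaluate`, it would need to be wrapped in a function with the same `(run, example)` signature as `check_execution`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sketch-subprocess-check",
      "metadata": {},
      "outputs": [],
      "source": [
        "import subprocess\n",
        "import sys\n",
        "\n",
        "\n",
        "def check_execution_subprocess(imports: str, code: str, timeout: float = 30.0) -> dict:\n",
        "    \"\"\"Hypothetical variant of check_execution that runs the snippet in a child interpreter.\"\"\"\n",
        "    source = (imports or \"\") + \"\\n\" + (code or \"\")\n",
        "    try:\n",
        "        result = subprocess.run(\n",
        "            [sys.executable, \"-c\", source],\n",
        "            capture_output=True,\n",
        "            text=True,\n",
        "            timeout=timeout,\n",
        "        )\n",
        "    except subprocess.TimeoutExpired:\n",
        "        # subprocess.run kills the child on timeout; count a hung snippet as a failure\n",
        "        return {\"key\": \"code_execution_check\", \"score\": 0}\n",
        "    # A non-zero exit code means the snippet raised an uncaught exception (or exited with an error)\n",
        "    score = 1 if result.returncode == 0 else 0\n",
        "    return {\"key\": \"code_execution_check\", \"score\": score}\n",
        "\n",
        "\n",
        "print(check_execution_subprocess(\"import json\", \"print(json.dumps([1, 2]))\"))\n",
        "print(check_execution_subprocess(\"import json\", \"json.loads('not json')\"))"
      ]
    },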
    {
      "cell_type": "markdown",
      "id": "c90bf261-0d94-4779-bbde-c76adeefe3d7",
      "metadata": {
        "id": "c90bf261-0d94-4779-bbde-c76adeefe3d7"
      },
      "source": [
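        "Compare the LangGraph workflow with a context-stuffing baseline: a single `code_gen_chain` call over the full documentation, with no check-and-retry loop."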
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "c8fa6bcb-b245-4422-b79a-582cd8a7d7ea",
      "metadata": {
        "id": "c8fa6bcb-b245-4422-b79a-582cd8a7d7ea"
      },
      "outputs": [],
      "source": [
        "def predict_base_case(example: dict):\n",
        "    \"\"\"Context stuffing\"\"\"\n",
        "    solution = code_gen_chain.invoke(\n",
        "        {\"context\": concatenated_content, \"messages\": [(\"user\", example[\"question\"])]}\n",
        "    )\n",
        "    return {\"imports\": solution.imports, \"code\": solution.code}\n",
        "\n",
        "\n",
        "def predict_langgraph(example: dict):\n",
        "    \"\"\"LangGraph\"\"\"\n",
        "    graph = app.invoke({\"messages\": [(\"user\", example[\"question\"])], \"iterations\": 0, \"error\": \"\"})\n",
        "    solution = graph[\"generation\"]\n",
        "    return {\"imports\": solution.imports, \"code\": solution.code}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "d9c57468-97f6-47d6-a5e9-c09b53bfdd83",
      "metadata": {
        "id": "d9c57468-97f6-47d6-a5e9-c09b53bfdd83"
      },
      "outputs": [],
      "source": [
        "from langsmith.evaluation import evaluate\n",
        "\n",
        "# Evaluator\n",
        "code_evalulator = [check_import, check_execution]\n",
        "\n",
        "# Dataset\n",
        "dataset_name = \"test-LCEL-code-gen\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "2dacccf0-d73f-4017-aaf0-9806ffe5bd2c",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 371
        },
        "id": "2dacccf0-d73f-4017-aaf0-9806ffe5bd2c",
        "outputId": "1b92e9f5-a268-4416-b000-f325aa8b1caf"
      },
      "outputs": [
        {
          "output_type": "error",
          "ename": "LangSmithNotFoundError",
          "evalue": "Dataset test-LCEL-code-gen not found",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mLangSmithNotFoundError\u001b[0m                    Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-16-c42adff5de4e>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# Run base case\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m experiment_results_ = evaluate(\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0mpredict_base_case\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mevaluators\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcode_evalulator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/evaluation/_runner.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(target, data, evaluators, summary_evaluators, metadata, experiment_prefix, description, max_concurrency, client, blocking)\u001b[0m\n\u001b[1;32m    233\u001b[0m         \u001b[0mView\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mevaluation\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mexperiment\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    234\u001b[0m     \"\"\"  # noqa: E501\n\u001b[0;32m--> 235\u001b[0;31m     return _evaluate(\n\u001b[0m\u001b[1;32m    236\u001b[0m         \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    237\u001b[0m         \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/evaluation/_runner.py\u001b[0m in \u001b[0;36m_evaluate\u001b[0;34m(target, data, evaluators, summary_evaluators, metadata, experiment_prefix, description, max_concurrency, client, blocking, experiment)\u001b[0m\n\u001b[1;32m    789\u001b[0m         \u001b[0mruns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mruns\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    790\u001b[0m         \u001b[0;31m# Create or resolve the experiment.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 791\u001b[0;31m     ).start()\n\u001b[0m\u001b[1;32m    792\u001b[0m     \u001b[0mcache_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mls_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_cache_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    793\u001b[0m     cache_path = (\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/evaluation/_runner.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1054\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1055\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_ExperimentManager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1056\u001b[0;31m         \u001b[0mfirst_example\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitertools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mislice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1057\u001b[0m         \u001b[0mproject\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_project\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfirst_example\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1058\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_print_experiment_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfirst_example\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/client.py\u001b[0m in \u001b[0;36mlist_examples\u001b[0;34m(self, dataset_id, dataset_name, example_ids, as_of, splits, inline_s3_urls, limit, metadata, **kwargs)\u001b[0m\n\u001b[1;32m   3144\u001b[0m             \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3145\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mdataset_name\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3146\u001b[0;31m             \u001b[0mdataset_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3147\u001b[0m             \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3148\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/utils.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     99\u001b[0m                     \u001b[0;34mf\" {', '.join(invalid_group_names)}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m                 )\n\u001b[0;32m--> 101\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    103\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/client.py\u001b[0m in \u001b[0;36mread_dataset\u001b[0;34m(self, dataset_name, dataset_id)\u001b[0m\n\u001b[1;32m   2400\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2401\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2402\u001b[0;31m                 raise ls_utils.LangSmithNotFoundError(\n\u001b[0m\u001b[1;32m   2403\u001b[0m                     \u001b[0;34mf\"Dataset {dataset_name} not found\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2404\u001b[0m                 )\n",
            "\u001b[0;31mLangSmithNotFoundError\u001b[0m: Dataset test-LCEL-code-gen not found"
          ]
        }
      ],
      "source": [
        "# Run base case\n",
        "experiment_results_ = evaluate(\n",
        "    predict_base_case,\n",
        "    data=dataset_name,\n",
        "    evaluators=code_evalulator,\n",
        "    experiment_prefix=f\"test-without-langgraph-{expt_llm}\",\n",
        "    max_concurrency=2,\n",
        "    metadata={\n",
        "        \"llm\": expt_llm,\n",
        "    },\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "71d90f9e-9dad-410c-a709-093d275029ae",
      "metadata": {
        "id": "71d90f9e-9dad-410c-a709-093d275029ae"
      },
      "outputs": [],
      "source": [
        "# Run with langgraph\n",
        "experiment_results = evaluate(\n",
        "    predict_langgraph,\n",
        "    data=dataset_name,\n",
        "    evaluators=code_evalulator,\n",
        "    experiment_prefix=f\"test-with-langgraph-{expt_llm}-{flag}\",\n",
        "    max_concurrency=2,\n",
        "    metadata={\n",
        "        \"llm\": expt_llm,\n",
        "        \"feedback\": flag,\n",
        "    },\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d69da747-b4ea-455d-9314-60c3d9d30549",
      "metadata": {
        "id": "d69da747-b4ea-455d-9314-60c3d9d30549"
      },
      "source": [
        "`Results:`\n",
        "\n",
        "* `LangGraph outperforms base case`: adding the retry loop improves performance\n",
        "* `Reflection did not help`: reflecting before each retry caused a regression compared with passing errors directly back to the LLM\n",
        "* `GPT-4 outperforms Claude 3`: with Claude 3, 3 runs failed for Opus and 1 for Haiku due to tool-use errors\n",
        "\n",
        "https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.2"
    },
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "L4"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 5
}