{
"cells": [
{
"cell_type": "markdown",
"id": "16dc0e41-80bd-4453-b421-dcf315741bf4",
"metadata": {
"id": "16dc0e41-80bd-4453-b421-dcf315741bf4"
},
"source": [
"# Code generation with flow\n",
"\n",
"AlphaCodium presented an approach for code generation that uses control flow.\n",
"\n",
"Main idea: [construct an answer to a coding question iteratively.](https://x.com/karpathy/status/1748043513156272416?s=20).\n",
"\n",
"[AlphaCodium](https://github.com/Codium-ai/AlphaCodium) iteravely tests and improves an answer on public and AI-generated tests for a particular question.\n",
"\n",
"We will implement some of these ideas from scratch using [LangGraph](https://python.langchain.com/docs/langgraph):\n",
"\n",
"1. We start with a set of documentation specified by a user\n",
"2. We use a long context LLM to ingest it, and answer a question based upon it\n",
"3. We perform two unit tests: Check imports and code execution\n",
"\n",
""
]
},
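{
"cell_type": "markdown",
"id": "flow-sketch-md",
"metadata": {},
"source": [
"Before the full implementation, here is a minimal, model-free sketch of that loop. `fake_generate` and `fake_check` are illustrative stand-ins (not part of the real pipeline) for the LLM call and the unit tests:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "flow-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the generate -> test -> retry loop built below.\n",
"# fake_generate / fake_check are illustrative stand-ins, not the real chain.\n",
"\n",
"\n",
"def fake_generate(feedback):\n",
"    # Pretend the model fixes its answer once it sees an error message\n",
"    if feedback is None:\n",
"        return \"print(undefined_name)\"  # first attempt: buggy\n",
"    return \"import math\\nprint(math.sqrt(4))\"  # retry: fixed\n",
"\n",
"\n",
"def fake_check(code):\n",
"    # Mirror the notebook's unit test: try to execute the code\n",
"    try:\n",
"        exec(code)\n",
"        return None\n",
"    except Exception as e:\n",
"        return str(e)\n",
"\n",
"\n",
"feedback = None\n",
"for attempt in range(3):  # bounded retries, like max_iterations below\n",
"    answer = fake_generate(feedback)\n",
"    feedback = fake_check(answer)\n",
"    if feedback is None:\n",
"        break\n",
"print(f\"solved on attempt {attempt + 1}\")"
]
},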
{
"cell_type": "code",
"execution_count": 1,
"id": "e3900420",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "e3900420",
"outputId": "5497b4e5-91cd-4ca3-aaef-20b99955f8dc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting langchain_community\n",
" Downloading langchain_community-0.0.38-py3-none-any.whl (2.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting langchain-openai\n",
" Downloading langchain_openai-0.1.7-py3-none-any.whl (34 kB)\n",
"Collecting langchain-anthropic\n",
" Downloading langchain_anthropic-0.1.12-py3-none-any.whl (16 kB)\n",
"Collecting langchain\n",
" Downloading langchain-0.1.20-py3-none-any.whl (1.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m44.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting langgraph\n",
" Downloading langgraph-0.0.49-py3-none-any.whl (77 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.6/77.6 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting bs4\n",
" Downloading bs4-0.0.2-py2.py3-none-any.whl (1.2 kB)\n",
"Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (6.0.1)\n",
"Requirement already satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.0.30)\n",
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (3.9.5)\n",
"Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)\n",
" Downloading dataclasses_json-0.6.6-py3-none-any.whl (28 kB)\n",
"Collecting langchain-core<0.2.0,>=0.1.52 (from langchain_community)\n",
" Downloading langchain_core-0.1.52-py3-none-any.whl (302 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.9/302.9 kB\u001b[0m \u001b[31m31.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting langsmith<0.2.0,>=0.1.0 (from langchain_community)\n",
" Downloading langsmith-0.1.59-py3-none-any.whl (121 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.2/121.2 kB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (1.25.2)\n",
"Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.31.0)\n",
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (8.3.0)\n",
"Collecting openai<2.0.0,>=1.24.0 (from langchain-openai)\n",
" Downloading openai-1.30.1-py3-none-any.whl (320 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.6/320.6 kB\u001b[0m \u001b[31m31.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting tiktoken<1,>=0.7 (from langchain-openai)\n",
" Downloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m59.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting anthropic<1,>=0.23.0 (from langchain-anthropic)\n",
" Downloading anthropic-0.25.9-py3-none-any.whl (871 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m871.1/871.1 kB\u001b[0m \u001b[31m58.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: defusedxml<0.8.0,>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from langchain-anthropic) (0.7.1)\n",
"Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.3)\n",
"Collecting langchain-text-splitters<0.1,>=0.0.1 (from langchain)\n",
" Downloading langchain_text_splitters-0.0.2-py3-none-any.whl (23 kB)\n",
"Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.7.1)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from bs4) (4.12.3)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.9.4)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (3.7.1)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.7.0)\n",
"Collecting httpx<1,>=0.23.0 (from anthropic<1,>=0.23.0->langchain-anthropic)\n",
" Downloading httpx-0.27.0-py3-none-any.whl (75 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.3.1)\n",
"Requirement already satisfied: tokenizers>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (0.19.1)\n",
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (4.11.0)\n",
"Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)\n",
" Downloading marshmallow-3.21.2-py3-none-any.whl (49 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.3/49.3 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)\n",
" Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)\n",
"Collecting jsonpatch<2.0,>=1.33 (from langchain-core<0.2.0,>=0.1.52->langchain_community)\n",
" Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n",
"Collecting packaging<24.0,>=23.2 (from langchain-core<0.2.0,>=0.1.52->langchain_community)\n",
" Downloading packaging-23.2-py3-none-any.whl (53 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.0/53.0 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.0->langchain_community)\n",
" Downloading orjson-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (142 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m142.5/142.5 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai<2.0.0,>=1.24.0->langchain-openai) (4.66.4)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.18.2 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (2.18.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2024.2.2)\n",
"Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain_community) (3.0.3)\n",
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken<1,>=0.7->langchain-openai) (2023.12.25)\n",
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->bs4) (2.5)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->anthropic<1,>=0.23.0->langchain-anthropic) (1.2.1)\n",
"Collecting httpcore==1.* (from httpx<1,>=0.23.0->anthropic<1,>=0.23.0->langchain-anthropic)\n",
" Downloading httpcore-1.0.5-py3-none-any.whl (77 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->anthropic<1,>=0.23.0->langchain-anthropic)\n",
" Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting jsonpointer>=1.9 (from jsonpatch<2.0,>=1.33->langchain-core<0.2.0,>=0.1.52->langchain_community)\n",
" Downloading jsonpointer-2.4-py2.py3-none-any.whl (7.8 kB)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (0.20.3)\n",
"Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain_community)\n",
" Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (3.14.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (2023.6.0)\n",
"Installing collected packages: packaging, orjson, mypy-extensions, jsonpointer, h11, typing-inspect, tiktoken, marshmallow, jsonpatch, httpcore, bs4, langsmith, httpx, dataclasses-json, openai, langchain-core, anthropic, langgraph, langchain-text-splitters, langchain-openai, langchain_community, langchain-anthropic, langchain\n",
" Attempting uninstall: packaging\n",
" Found existing installation: packaging 24.0\n",
" Uninstalling packaging-24.0:\n",
" Successfully uninstalled packaging-24.0\n",
"Successfully installed anthropic-0.25.9 bs4-0.0.2 dataclasses-json-0.6.6 h11-0.14.0 httpcore-1.0.5 httpx-0.27.0 jsonpatch-1.33 jsonpointer-2.4 langchain-0.1.20 langchain-anthropic-0.1.12 langchain-core-0.1.52 langchain-openai-0.1.7 langchain-text-splitters-0.0.2 langchain_community-0.0.38 langgraph-0.0.49 langsmith-0.1.59 marshmallow-3.21.2 mypy-extensions-1.0.0 openai-1.30.1 orjson-3.10.3 packaging-23.2 tiktoken-0.7.0 typing-inspect-0.9.0\n"
]
}
],
"source": [
"! pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4"
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"os.environ['OPENAI_API_KEY'] = ''\n",
"os.environ['ANTHROPIC_API_KEY'] = ''\n",
"\n",
"# Optional, add tracing in LangSmith\n",
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'\n",
"os.environ[\"LANGCHAIN_PROJECT\"] = \"langgraph_code_assistant\"\n",
"os.environ['LANGCHAIN_API_KEY'] = ''"
],
"metadata": {
"id": "thz4d-ctMrIC"
},
"id": "thz4d-ctMrIC",
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"id": "38330223-d8c8-4156-82b6-93e63343bc01",
"metadata": {
"id": "38330223-d8c8-4156-82b6-93e63343bc01"
},
"source": [
"## Docs\n",
"\n",
"Load [LangChain Expression Language](https://python.langchain.com/docs/expression_language/) (LCEL) docs as an example."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c2eb35d1-4990-47dc-a5c4-208bae588a82",
"metadata": {
"id": "c2eb35d1-4990-47dc-a5c4-208bae588a82"
},
"outputs": [],
"source": [
"from bs4 import BeautifulSoup as Soup\n",
"from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader\n",
"\n",
"# LCEL docs\n",
"url = \"https://github.com/kevinkawchak/Medical-Quantum-Machine-Learning/tree/main/Code/Drug%20Discovery/Meta-Llama-3\"\n",
"loader = RecursiveUrlLoader(\n",
" url=url, max_depth=20, extractor=lambda x: Soup(x, \"html.parser\").text\n",
")\n",
"docs = loader.load()\n",
"\n",
"# Sort the list based on the URLs and get the text\n",
"d_sorted = sorted(docs, key=lambda x: x.metadata[\"source\"])\n",
"d_reversed = list(reversed(d_sorted))\n",
"concatenated_content = \"\\n\\n\\n --- \\n\\n\\n\".join(\n",
" [doc.page_content for doc in d_reversed]\n",
")"
]
},
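{
"cell_type": "markdown",
"id": "docs-sanity-md",
"metadata": {},
"source": [
"Optional sanity check: confirm what the loader returned before stuffing it into the prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "docs-sanity-code",
"metadata": {},
"outputs": [],
"source": [
"# Quick look at the crawled docs and the size of the concatenated context\n",
"print(f\"{len(docs)} documents loaded\")\n",
"if docs:\n",
"    print(f\"first source: {d_reversed[0].metadata['source']}\")\n",
"print(f\"~{len(concatenated_content):,} characters of context\")"
]
},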
{
"cell_type": "markdown",
"id": "662d4ff4-1709-412f-bfed-5eb2b8d3d3dc",
"metadata": {
"id": "662d4ff4-1709-412f-bfed-5eb2b8d3d3dc"
},
"source": [
"## LLMs\n",
"\n",
"### Code solution\n",
"\n",
"Try OpenAI and [Claude3](https://docs.anthropic.com/claude/docs/models-overview) with function calling."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3ba3df70-f6b4-4ea5-a210-e10944960bc6",
"metadata": {
"id": "3ba3df70-f6b4-4ea5-a210-e10944960bc6"
},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"### OpenAI\n",
"\n",
"# Grader prompt\n",
"code_gen_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"\"\"You are a coding assistant with expertise in Python. \\n\n",
" Here is a full set of documentation: \\n ------- \\n {context} \\n ------- \\n Answer the user\n",
" question based on the above provided documentation. Ensure any code you provide can be executed \\n\n",
" with all required imports and variables defined. Structure your answer with a description of the code solution. \\n\n",
" Then list the imports. And finally list the functioning code block. Here is the user question:\"\"\",\n",
" ),\n",
" (\"placeholder\", \"{messages}\"),\n",
" ]\n",
")\n",
"\n",
"\n",
"# Data model\n",
"class code(BaseModel):\n",
" \"\"\"Code output\"\"\"\n",
"\n",
" prefix: str = Field(description=\"Description of the problem and approach\")\n",
" imports: str = Field(description=\"Code block import statements\")\n",
" code: str = Field(description=\"Code block not including import statements\")\n",
" description = \"Schema for code solutions to questions about Python.\"\n",
"\n",
"\n",
"expt_llm = \"gpt-4o-2024-05-13\"\n",
"llm = ChatOpenAI(temperature=0, model=expt_llm)\n",
"code_gen_chain = code_gen_prompt | llm.with_structured_output(code)\n",
"question = \"How do I Fine-tune Llama 3 with the Mol-Instructions dataset?\"\n",
"# solution = code_gen_chain_oai.invoke({\"context\":concatenated_content,\"messages\":[(\"user\",question)]})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cd30b67d-96db-4e51-a540-ae23fcc1f878",
"metadata": {
"id": "cd30b67d-96db-4e51-a540-ae23fcc1f878"
},
"outputs": [],
"source": [
"# from langchain_anthropic import ChatAnthropic\n",
"# from langchain_core.prompts import ChatPromptTemplate\n",
"# from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"# ### Anthropic\n",
"\n",
"# # Prompt to enforce tool use\n",
"# code_gen_prompt_claude = ChatPromptTemplate.from_messages(\n",
"# [\n",
"# (\n",
"# \"system\",\n",
"# \"\"\"<instructions> You are a coding assistant with expertise in Python. \\n\n",
"# Here is the documentation: \\n ------- \\n {context} \\n ------- \\n Answer the user question based on the \\n\n",
"# above provided documentation. Ensure any code you provide can be executed with all required imports and variables \\n\n",
"# defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \\n\n",
"# Invoke the code tool to structure the output correctly. </instructions> \\n Here is the user question:\"\"\",\n",
"# ),\n",
"# (\"placeholder\", \"{messages}\"),\n",
"# ]\n",
"# )\n",
"\n",
"\n",
"# # Data model\n",
"# class code(BaseModel):\n",
"# \"\"\"Code output\"\"\"\n",
"\n",
"# prefix: str = Field(description=\"Description of the problem and approach\")\n",
"# imports: str = Field(description=\"Code block import statements\")\n",
"# code: str = Field(description=\"Code block not including import statements\")\n",
"# description = \"Schema for code solutions to questions about Python.\"\n",
"\n",
"\n",
"# # LLM\n",
"# # expt_llm = \"claude-3-haiku-20240307\"\n",
"# expt_llm = \"claude-3-opus-20240229\"\n",
"# llm = ChatAnthropic(\n",
"# model=expt_llm,\n",
"# default_headers={\"anthropic-beta\": \"tools-2024-04-04\"},\n",
"# )\n",
"\n",
"# structured_llm_claude = llm.with_structured_output(code, include_raw=True)\n",
"\n",
"\n",
"# # Optional: Check for errors in case tool use is flaky\n",
"# def check_claude_output(tool_output):\n",
"# \"\"\"Check for parse error or failure to call the tool\"\"\"\n",
"\n",
"# # Error with parsing\n",
"# if tool_output[\"parsing_error\"]:\n",
"# # Report back output and parsing errors\n",
"# print(\"Parsing error!\")\n",
"# raw_output = str(code_output[\"raw\"].content)\n",
"# error = tool_output[\"parsing_error\"]\n",
"# raise ValueError(\n",
"# f\"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \\n Parse error: {error}\"\n",
"# )\n",
"\n",
"# # Tool was not invoked\n",
"# elif not tool_output[\"parsed\"]:\n",
"# print(\"Failed to invoke tool!\")\n",
"# raise ValueError(\n",
"# f\"You did not use the provided tool! Be sure to invoke the tool to structure the output.\"\n",
"# )\n",
"# return tool_output\n",
"\n",
"\n",
"# # Chain with output check\n",
"# code_chain_claude_raw = (\n",
"# code_gen_prompt_claude | structured_llm_claude | check_claude_output\n",
"# )\n",
"\n",
"\n",
"# def insert_errors(inputs):\n",
"# \"\"\"Insert errors for tool parsing in the messages\"\"\"\n",
"\n",
"# # Get errors\n",
"# error = inputs[\"error\"]\n",
"# messages = inputs[\"messages\"]\n",
"# messages += [\n",
"# (\n",
"# \"assistant\",\n",
"# f\"Retry. You are required to fix the parsing errors: {error} \\n\\n You must invoke the provided tool.\",\n",
"# )\n",
"# ]\n",
"# return {\n",
"# \"messages\": messages,\n",
"# \"context\": inputs[\"context\"],\n",
"# }\n",
"\n",
"\n",
"# # This will be run as a fallback chain\n",
"# fallback_chain = insert_errors | code_chain_claude_raw\n",
"# N = 3 # Max re-tries\n",
"# code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(\n",
"# fallbacks=[fallback_chain] * N, exception_key=\"error\"\n",
"# )\n",
"\n",
"\n",
"# def parse_output(solution):\n",
"# \"\"\"When we add 'include_raw=True' to structured output,\n",
"# it will return a dict w 'raw', 'parsed', 'parsing_error'.\"\"\"\n",
"\n",
"# return solution[\"parsed\"]\n",
"\n",
"\n",
"# # With re-try to correct for failure to invoke tool\n",
"# # TODO: Annoying errors w/ \"user\" vs \"assistant\"\n",
"# # Roles must alternate between \"user\" and \"assistant\", but found multiple \"user\" roles in a row\n",
"# code_gen_chain = code_gen_chain_re_try | parse_output\n",
"\n",
"# # No re-try\n",
"# code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9f14750f-dddc-485b-ba29-5392cdf4ba43",
"metadata": {
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "9f14750f-dddc-485b-ba29-5392cdf4ba43",
"outputId": "f7ff980b-e9de-4c52-d400-ce259e10b215"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"code(prefix='Fine-tuning Llama 3 with the Mol-Instructions dataset involves several steps. Below is a Python code example that demonstrates how to fine-tune the Llama 3 model using the Hugging Face Transformers library and the Mol-Instructions dataset. This example assumes you have access to the necessary hardware (e.g., a GPU) and have installed the required libraries.', imports='import torch\\nfrom transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset', code=\"# Load the pre-trained Llama 3 model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = LlamaForCausalLM.from_pretrained(model_name)\\ntokenizer = LlamaTokenizer.from_pretrained(model_name)\\n\\n# Load the Mol-Instructions dataset\\ndataset = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_datasets = dataset.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n output_dir='./results',\\n evaluation_strategy='epoch',\\n learning_rate=2e-5,\\n per_device_train_batch_size=4,\\n per_device_eval_batch_size=4,\\n num_train_epochs=3,\\n weight_decay=0.01,\\n logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n model=model,\\n args=training_args,\\n train_dataset=tokenized_datasets['train'],\\n eval_dataset=tokenized_datasets['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\", description='This code demonstrates how to fine-tune the Llama 3 model using the Mol-Instructions dataset. It includes loading the pre-trained model and tokenizer, loading and tokenizing the dataset, setting training arguments, and initializing and running the Trainer for fine-tuning.')"
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"# Test\n",
"question = \"How do I fine-tune Llama 3 with the Mol-Instructions dataset?\"\n",
"solution = code_gen_chain.invoke(\n",
" {\"context\": concatenated_content, \"messages\": [(\"user\", question)]}\n",
")\n",
"solution"
]
},
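{
"cell_type": "markdown",
"id": "solution-fields-md",
"metadata": {},
"source": [
"The structured output is a `code` object, so its fields can be read directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "solution-fields-code",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the structured fields of the solution\n",
"print(solution.prefix)\n",
"print(\"---\")\n",
"print(solution.imports)\n",
"print(\"---\")\n",
"print(solution.code)"
]
},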
{
"cell_type": "markdown",
"id": "131f2055-2f64-4d19-a3d1-2d3cb8b42894",
"metadata": {
"id": "131f2055-2f64-4d19-a3d1-2d3cb8b42894"
},
"source": [
"## State\n",
"\n",
"Our state is a dict that will contain keys (errors, question, code generation) relevant to code generation."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c185f1a2-e943-4bed-b833-4243c9c64092",
"metadata": {
"id": "c185f1a2-e943-4bed-b833-4243c9c64092"
},
"outputs": [],
"source": [
"from typing import Dict, TypedDict, List\n",
"\n",
"\n",
"class GraphState(TypedDict):\n",
" \"\"\"\n",
" Represents the state of our graph.\n",
"\n",
" Attributes:\n",
" error : Binary flag for control flow to indicate whether test error was tripped\n",
" messages : With user question, error messages, reasoning\n",
" generation : Code solution\n",
" iterations : Number of tries\n",
" \"\"\"\n",
"\n",
" error: str\n",
" messages: List\n",
" generation: str\n",
" iterations: int"
]
},
{
"cell_type": "markdown",
"id": "64454465-26a3-40de-ad85-bcf59a2c3086",
"metadata": {
"id": "64454465-26a3-40de-ad85-bcf59a2c3086"
},
"source": [
"## Graph\n",
"\n",
"Our graph lays out the logical flow shown in the figure above."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b70e8301-63ae-4f7e-ad8f-c9a052fe3566",
"metadata": {
"id": "b70e8301-63ae-4f7e-ad8f-c9a052fe3566"
},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.prompts import PromptTemplate\n",
"\n",
"### Parameter\n",
"\n",
"# Max tries\n",
"max_iterations = 3\n",
"# Reflect\n",
"# flag = 'reflect'\n",
"flag = \"do not reflect\"\n",
"\n",
"### Nodes\n",
"\n",
"\n",
"def generate(state: GraphState):\n",
" \"\"\"\n",
" Generate a code solution\n",
"\n",
" Args:\n",
" state (dict): The current graph state\n",
"\n",
" Returns:\n",
" state (dict): New key added to state, generation\n",
" \"\"\"\n",
"\n",
" print(\"---GENERATING CODE SOLUTION---\")\n",
"\n",
" # State\n",
" messages = state[\"messages\"]\n",
" iterations = state[\"iterations\"]\n",
" error = state[\"error\"]\n",
"\n",
" # We have been routed back to generation with an error\n",
" if error == \"yes\":\n",
" messages += [\n",
" (\n",
" \"user\",\n",
" \"Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:\",\n",
" )\n",
" ]\n",
"\n",
" # Solution\n",
" code_solution = code_gen_chain.invoke(\n",
" {\"context\": concatenated_content, \"messages\": messages}\n",
" )\n",
" messages += [\n",
" (\n",
" \"assistant\",\n",
" f\"{code_solution.prefix} \\n Imports: {code_solution.imports} \\n Code: {code_solution.code}\",\n",
" )\n",
" ]\n",
"\n",
" # Increment\n",
" iterations = iterations + 1\n",
" return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n",
"\n",
"\n",
"def code_check(state: GraphState):\n",
" \"\"\"\n",
" Check code\n",
"\n",
" Args:\n",
" state (dict): The current graph state\n",
"\n",
" Returns:\n",
" state (dict): New key added to state, error\n",
" \"\"\"\n",
"\n",
" print(\"---CHECKING CODE---\")\n",
"\n",
" # State\n",
" messages = state[\"messages\"]\n",
" code_solution = state[\"generation\"]\n",
" iterations = state[\"iterations\"]\n",
"\n",
" # Get solution components\n",
" prefix = code_solution.prefix\n",
" imports = code_solution.imports\n",
" code = code_solution.code\n",
"\n",
" # Check imports\n",
" try:\n",
" exec(imports)\n",
" except Exception as e:\n",
" print(\"---CODE IMPORT CHECK: FAILED---\")\n",
" error_message = [(\"user\", f\"Your solution failed the import test: {e}\")]\n",
" messages += error_message\n",
" return {\n",
" \"generation\": code_solution,\n",
" \"messages\": messages,\n",
" \"iterations\": iterations,\n",
" \"error\": \"yes\",\n",
" }\n",
"\n",
" # Check execution\n",
" try:\n",
" exec(imports + \"\\n\" + code)\n",
" except Exception as e:\n",
" print(\"---CODE BLOCK CHECK: FAILED---\")\n",
" error_message = [(\"user\", f\"Your solution failed the code execution test: {e}\")]\n",
" messages += error_message\n",
" return {\n",
" \"generation\": code_solution,\n",
" \"messages\": messages,\n",
" \"iterations\": iterations,\n",
" \"error\": \"yes\",\n",
" }\n",
"\n",
" # No errors\n",
" print(\"---NO CODE TEST FAILURES---\")\n",
" return {\n",
" \"generation\": code_solution,\n",
" \"messages\": messages,\n",
" \"iterations\": iterations,\n",
" \"error\": \"no\",\n",
" }\n",
"\n",
"\n",
"def reflect(state: GraphState):\n",
" \"\"\"\n",
" Reflect on errors\n",
"\n",
" Args:\n",
" state (dict): The current graph state\n",
"\n",
" Returns:\n",
" state (dict): New key added to state, generation\n",
" \"\"\"\n",
"\n",
" print(\"---GENERATING CODE SOLUTION---\")\n",
"\n",
" # State\n",
" messages = state[\"messages\"]\n",
" iterations = state[\"iterations\"]\n",
" code_solution = state[\"generation\"]\n",
"\n",
" # Prompt reflection\n",
" reflection_message = [\n",
" (\n",
" \"user\",\n",
" \"\"\"You tried to solve this problem and failed a unit test. Reflect on this failure\n",
" given the provided documentation. Write a few key suggestions based on the\n",
" documentation to avoid making this mistake again.\"\"\",\n",
" )\n",
" ]\n",
"\n",
" # Add reflection\n",
" reflections = code_gen_chain.invoke(\n",
" {\"context\": concatenated_content, \"messages\": messages}\n",
" )\n",
" messages += [(\"assistant\", f\"Here are reflections on the error: {reflections}\")]\n",
" return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n",
"\n",
"\n",
"### Edges\n",
"\n",
"\n",
"def decide_to_finish(state: GraphState):\n",
" \"\"\"\n",
" Determines whether to finish.\n",
"\n",
" Args:\n",
" state (dict): The current graph state\n",
"\n",
" Returns:\n",
" str: Next node to call\n",
" \"\"\"\n",
" error = state[\"error\"]\n",
" iterations = state[\"iterations\"]\n",
"\n",
" if error == \"no\" or iterations == max_iterations:\n",
" print(\"---DECISION: FINISH---\")\n",
" return \"end\"\n",
" else:\n",
" print(\"---DECISION: RE-TRY SOLUTION---\")\n",
" if flag == \"reflect\":\n",
" return \"reflect\"\n",
" else:\n",
" return \"generate\""
]
},
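{
"cell_type": "markdown",
"id": "exec-check-md",
"metadata": {},
"source": [
"Note that `code_check` runs generated code with `exec` in the current process, so failures surface as ordinary exceptions. A standalone sketch of the import check (the failing module name is intentionally fake):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "exec-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Standalone sketch of the exec-based import check used in code_check\n",
"def try_imports(imports_block: str) -> str:\n",
"    try:\n",
"        exec(imports_block)\n",
"        return \"no\"  # same error-flag values code_check uses\n",
"    except Exception as e:\n",
"        print(f\"import failed: {e}\")\n",
"        return \"yes\"\n",
"\n",
"\n",
"print(try_imports(\"import os\"))\n",
"print(try_imports(\"import not_a_real_module\"))"
]
},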
{
"cell_type": "code",
"execution_count": 9,
"id": "f66b4e00-4731-42c8-bc38-72dd0ff7c92c",
"metadata": {
"id": "f66b4e00-4731-42c8-bc38-72dd0ff7c92c"
},
"outputs": [],
"source": [
"from langgraph.graph import END, StateGraph\n",
"\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Define the nodes\n",
"workflow.add_node(\"generate\", generate) # generation solution\n",
"workflow.add_node(\"check_code\", code_check) # check code\n",
"workflow.add_node(\"reflect\", reflect) # reflect\n",
"\n",
"# Build graph\n",
"workflow.set_entry_point(\"generate\")\n",
"workflow.add_edge(\"generate\", \"check_code\")\n",
"workflow.add_conditional_edges(\n",
" \"check_code\",\n",
" decide_to_finish,\n",
" {\n",
" \"end\": END,\n",
" \"reflect\": \"reflect\",\n",
" \"generate\": \"generate\",\n",
" },\n",
")\n",
"workflow.add_edge(\"reflect\", \"generate\")\n",
"app = workflow.compile()"
]
},
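{
"cell_type": "markdown",
"id": "graph-viz-md",
"metadata": {},
"source": [
"Optionally, render the compiled topology. `get_graph()` / `draw_mermaid()` ship with newer `langgraph` releases; their availability in the version pinned above is an assumption, hence the guard:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "graph-viz-code",
"metadata": {},
"outputs": [],
"source": [
"# Optional: print a Mermaid diagram of the compiled graph.\n",
"# Assumption: get_graph()/draw_mermaid() exist in your langgraph version.\n",
"try:\n",
"    print(app.get_graph().draw_mermaid())\n",
"except AttributeError:\n",
"    print(\"Graph drawing not supported by this langgraph version\")"
]
},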
{
"cell_type": "code",
"execution_count": 10,
"id": "9bcaafe4-ddcf-4fab-8620-2d9b6c508f98",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "9bcaafe4-ddcf-4fab-8620-2d9b6c508f98",
"outputId": "c1f7eed0-9ce3-4d9b-99d2-6039d2fbc882"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"---GENERATING CODE SOLUTION---\n",
"---CHECKING CODE---\n",
"---CODE IMPORT CHECK: FAILED---\n",
"---DECISION: RE-TRY SOLUTION---\n",
"---GENERATING CODE SOLUTION---\n",
"---CHECKING CODE---\n",
"---CODE IMPORT CHECK: FAILED---\n",
"---DECISION: RE-TRY SOLUTION---\n",
"---GENERATING CODE SOLUTION---\n",
"---CHECKING CODE---\n",
"---CODE IMPORT CHECK: FAILED---\n",
"---DECISION: FINISH---\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'error': 'yes',\n",
" 'messages': [('user',\n",
" 'How can I fine-tune Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset?'),\n",
" ('assistant',\n",
" \"Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset \\n Imports: import torch\\nfrom transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset \\n Code: # Load the model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = LlamaForCausalLM.from_pretrained(model_name)\\ntokenizer = LlamaTokenizer.from_pretrained(model_name)\\n\\n# Load the dataset\\nmol_instructions = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_dataset = mol_instructions.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n output_dir='./results',\\n evaluation_strategy='epoch',\\n learning_rate=2e-5,\\n per_device_train_batch_size=4,\\n per_device_eval_batch_size=4,\\n num_train_epochs=3,\\n weight_decay=0.01,\\n logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n model=model,\\n args=training_args,\\n train_dataset=tokenized_dataset['train'],\\n eval_dataset=tokenized_dataset['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\"),\n",
" ('user', \"Your solution failed the import test: No module named 'datasets'\"),\n",
" ('user',\n",
" 'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n",
" ('assistant',\n",
" \"Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset \\n Imports: import torch\\nfrom transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset \\n Code: # Load the model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = LlamaForCausalLM.from_pretrained(model_name)\\ntokenizer = LlamaTokenizer.from_pretrained(model_name)\\n\\n# Load the dataset\\nmol_instructions = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_dataset = mol_instructions.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n output_dir='./results',\\n evaluation_strategy='epoch',\\n learning_rate=2e-5,\\n per_device_train_batch_size=4,\\n per_device_eval_batch_size=4,\\n num_train_epochs=3,\\n weight_decay=0.01,\\n logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n model=model,\\n args=training_args,\\n train_dataset=tokenized_dataset['train'],\\n eval_dataset=tokenized_dataset['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\"),\n",
" ('user', \"Your solution failed the import test: No module named 'datasets'\"),\n",
" ('user',\n",
" 'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n",
" ('assistant',\n",
" \"Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset \\n Imports: import torch\\nfrom transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset \\n Code: # Load the model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = LlamaForCausalLM.from_pretrained(model_name)\\ntokenizer = LlamaTokenizer.from_pretrained(model_name)\\n\\n# Load the dataset\\nmol_instructions = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_dataset = mol_instructions.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n output_dir='./results',\\n evaluation_strategy='epoch',\\n learning_rate=2e-5,\\n per_device_train_batch_size=4,\\n per_device_eval_batch_size=4,\\n num_train_epochs=3,\\n weight_decay=0.01,\\n logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n model=model,\\n args=training_args,\\n train_dataset=tokenized_dataset['train'],\\n eval_dataset=tokenized_dataset['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\"),\n",
" ('user',\n",
" \"Your solution failed the import test: No module named 'datasets'\")],\n",
" 'generation': code(prefix='Fine-tuning Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset', imports='import torch\\nfrom transformers import LlamaForCausalLM, LlamaTokenizer, Trainer, TrainingArguments\\nfrom datasets import load_dataset', code=\"# Load the model and tokenizer\\nmodel_name = 'gradientai/Llama-3-8B-Instruct-Gradient-1048k'\\nmodel = LlamaForCausalLM.from_pretrained(model_name)\\ntokenizer = LlamaTokenizer.from_pretrained(model_name)\\n\\n# Load the dataset\\nmol_instructions = load_dataset('zjunlp/Mol-Instructions')\\n\\n# Tokenize the dataset\\ndef tokenize_function(examples):\\n return tokenizer(examples['text'], padding='max_length', truncation=True)\\n\\ntokenized_dataset = mol_instructions.map(tokenize_function, batched=True)\\n\\n# Set training arguments\\ntraining_args = TrainingArguments(\\n output_dir='./results',\\n evaluation_strategy='epoch',\\n learning_rate=2e-5,\\n per_device_train_batch_size=4,\\n per_device_eval_batch_size=4,\\n num_train_epochs=3,\\n weight_decay=0.01,\\n logging_dir='./logs',\\n)\\n\\n# Initialize the Trainer\\ntrainer = Trainer(\\n model=model,\\n args=training_args,\\n train_dataset=tokenized_dataset['train'],\\n eval_dataset=tokenized_dataset['validation'],\\n)\\n\\n# Fine-tune the model\\ntrainer.train()\", description='Schema for code solutions to questions about Python.'),\n",
" 'iterations': 3}"
]
},
"metadata": {},
"execution_count": 10
}
],
"source": [
"question = \"How can I fine-tune Llama-3-8B-Instruct-Gradient-1048k with the Mol-Instructions dataset?\"\n",
"app.invoke({\"messages\": [(\"user\", question)], \"iterations\": 0})"
]
},
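{
"cell_type": "markdown",
"id": "final-state-md",
"metadata": {},
"source": [
"The returned state carries the final structured solution under `generation`. Binding the result lets you pull out the imports and code (note this re-invokes the graph, i.e. another round of LLM calls):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "final-state-code",
"metadata": {},
"outputs": [],
"source": [
"# Capture the final state and extract the structured solution\n",
"final_state = app.invoke({\"messages\": [(\"user\", question)], \"iterations\": 0})\n",
"final_solution = final_state[\"generation\"]\n",
"print(final_solution.imports)\n",
"print(final_solution.code)"
]
},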
{
"cell_type": "markdown",
"id": "744f48a5-9ad3-4342-899f-7dd4266a9a15",
"metadata": {
"id": "744f48a5-9ad3-4342-899f-7dd4266a9a15"
},
"source": [
"## Eval"
]
},
{
"cell_type": "markdown",
"id": "89852874-b538-4c8d-a4c3-1d68302db492",
"metadata": {
"id": "89852874-b538-4c8d-a4c3-1d68302db492"
},
"source": [
"[Here](https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d) is a public dataset of LCEL questions.\n",
"\n",
"I saved this as `test-LCEL-code-gen`.\n",
"\n",
"You can also find the csv [here](https://github.com/langchain-ai/lcel-teacher/blob/main/eval/eval.csv)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "678e8954-56b5-4cc6-be26-f7f2a060b242",
"metadata": {
"id": "678e8954-56b5-4cc6-be26-f7f2a060b242"
},
"outputs": [],
"source": [
"import langsmith\n",
"\n",
"client = langsmith.Client()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ef7cf662-7a6f-4dee-965c-6309d4045feb",
"metadata": {
"id": "ef7cf662-7a6f-4dee-965c-6309d4045feb"
},
"outputs": [],
"source": [
"# Clone the dataset to your tenant to use it\n",
"public_dataset = (\n",
" \"https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d\"\n",
")\n",
"client.clone_public_dataset(public_dataset)"
]
},
{
"cell_type": "markdown",
"id": "9d171396-022b-47ec-a741-c782aff9fdae",
"metadata": {
"id": "9d171396-022b-47ec-a741-c782aff9fdae"
},
"source": [
"Custom evals."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "455a34ea-52cb-4ae5-9f4a-7e4a08cd0c09",
"metadata": {
"id": "455a34ea-52cb-4ae5-9f4a-7e4a08cd0c09"
},
"outputs": [],
"source": [
"from langsmith.schemas import Example, Run\n",
"\n",
"\n",
"def check_import(run: Run, example: Example) -> dict:\n",
" imports = run.outputs.get(\"imports\")\n",
" try:\n",
" exec(imports)\n",
" return {\"key\": \"import_check\", \"score\": 1}\n",
" except:\n",
" return {\"key\": \"import_check\", \"score\": 0}\n",
"\n",
"\n",
"def check_execution(run: Run, example: Example) -> dict:\n",
" imports = run.outputs.get(\"imports\")\n",
" code = run.outputs.get(\"code\")\n",
" try:\n",
" exec(imports + \"\\n\" + code)\n",
" return {\"key\": \"code_execution_check\", \"score\": 1}\n",
" except:\n",
" return {\"key\": \"code_execution_check\", \"score\": 0}"
]
},
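{
"cell_type": "markdown",
"id": "local-eval-md",
"metadata": {},
"source": [
"A quick local smoke test of the same exec-based scoring, outside LangSmith. The `outputs` dict mirrors the shape returned by the predict functions below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "local-eval-code",
"metadata": {},
"outputs": [],
"source": [
"# Local smoke test of the exec-based checks, no LangSmith run required\n",
"outputs = {\"imports\": \"import math\", \"code\": \"print(math.pi)\"}\n",
"try:\n",
"    exec(outputs[\"imports\"] + \"\\n\" + outputs[\"code\"])\n",
"    print({\"key\": \"code_execution_check\", \"score\": 1})\n",
"except Exception:\n",
"    print({\"key\": \"code_execution_check\", \"score\": 0})"
]
},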
{
"cell_type": "markdown",
"id": "c90bf261-0d94-4779-bbde-c76adeefe3d7",
"metadata": {
"id": "c90bf261-0d94-4779-bbde-c76adeefe3d7"
},
"source": [
"Compare LangGraph to Context Stuffing."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c8fa6bcb-b245-4422-b79a-582cd8a7d7ea",
"metadata": {
"id": "c8fa6bcb-b245-4422-b79a-582cd8a7d7ea"
},
"outputs": [],
"source": [
"def predict_base_case(example: dict):\n",
" \"\"\"Context stuffing\"\"\"\n",
" solution = code_gen_chain.invoke(\n",
" {\"context\": concatenated_content, \"messages\": [(\"user\", example[\"question\"])]}\n",
" )\n",
" solution_structured = structured_code_formatter.invoke([(\"code\", solution)])\n",
" return {\"imports\": solution_structured.imports, \"code\": solution_structured.code}\n",
"\n",
"\n",
"def predict_langgraph(example: dict):\n",
" \"\"\"LangGraph\"\"\"\n",
" graph = app.invoke({\"messages\": [(\"user\", example[\"question\"])], \"iterations\": 0})\n",
" solution = graph[\"generation\"]\n",
" return {\"imports\": solution.imports, \"code\": solution.code}"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d9c57468-97f6-47d6-a5e9-c09b53bfdd83",
"metadata": {
"id": "d9c57468-97f6-47d6-a5e9-c09b53bfdd83"
},
"outputs": [],
"source": [
"from langsmith.evaluation import evaluate\n",
"\n",
"# Evaluator\n",
"code_evalulator = [check_import, check_execution]\n",
"\n",
"# Dataset\n",
"dataset_name = \"test-LCEL-code-gen\""
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2dacccf0-d73f-4017-aaf0-9806ffe5bd2c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 371
},
"id": "2dacccf0-d73f-4017-aaf0-9806ffe5bd2c",
"outputId": "d3a9493a-412d-41c3-a7cb-1c5cb58662f0"
},
"outputs": [
{
"output_type": "error",
"ename": "LangSmithNotFoundError",
"evalue": "Dataset test-LCEL-code-gen not found",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mLangSmithNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-16-c42adff5de4e>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Run base case\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m experiment_results_ = evaluate(\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mpredict_base_case\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mevaluators\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcode_evalulator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/evaluation/_runner.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(target, data, evaluators, summary_evaluators, metadata, experiment_prefix, description, max_concurrency, client, blocking)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0mView\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mevaluation\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mexperiment\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 234\u001b[0m \"\"\" # noqa: E501\n\u001b[0;32m--> 235\u001b[0;31m return _evaluate(\n\u001b[0m\u001b[1;32m 236\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/evaluation/_runner.py\u001b[0m in \u001b[0;36m_evaluate\u001b[0;34m(target, data, evaluators, summary_evaluators, metadata, experiment_prefix, description, max_concurrency, client, blocking, experiment)\u001b[0m\n\u001b[1;32m 789\u001b[0m \u001b[0mruns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mruns\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[0;31m# Create or resolve the experiment.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 791\u001b[0;31m ).start()\n\u001b[0m\u001b[1;32m 792\u001b[0m \u001b[0mcache_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mls_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_cache_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 793\u001b[0m cache_path = (\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/evaluation/_runner.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1054\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1055\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_ExperimentManager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1056\u001b[0;31m \u001b[0mfirst_example\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitertools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mislice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1057\u001b[0m \u001b[0mproject\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_project\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfirst_example\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1058\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_print_experiment_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproject\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfirst_example\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/client.py\u001b[0m in \u001b[0;36mlist_examples\u001b[0;34m(self, dataset_id, dataset_name, example_ids, as_of, splits, inline_s3_urls, limit, metadata, **kwargs)\u001b[0m\n\u001b[1;32m 3144\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3145\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mdataset_name\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3146\u001b[0;31m \u001b[0mdataset_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdataset_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3147\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dataset\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3148\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/utils.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;34mf\" {', '.join(invalid_group_names)}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m )\n\u001b[0;32m--> 101\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/langsmith/client.py\u001b[0m in \u001b[0;36mread_dataset\u001b[0;34m(self, dataset_name, dataset_id)\u001b[0m\n\u001b[1;32m 2400\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2401\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2402\u001b[0;31m raise ls_utils.LangSmithNotFoundError(\n\u001b[0m\u001b[1;32m 2403\u001b[0m \u001b[0;34mf\"Dataset {dataset_name} not found\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2404\u001b[0m )\n",
"\u001b[0;31mLangSmithNotFoundError\u001b[0m: Dataset test-LCEL-code-gen not found"
]
}
],
"source": [
"# Run base case\n",
"experiment_results_ = evaluate(\n",
" predict_base_case,\n",
" data=dataset_name,\n",
" evaluators=code_evalulator,\n",
" experiment_prefix=f\"test-without-langgraph-{expt_llm}\",\n",
" max_concurrency=2,\n",
" metadata={\n",
" \"llm\": expt_llm,\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71d90f9e-9dad-410c-a709-093d275029ae",
"metadata": {
"id": "71d90f9e-9dad-410c-a709-093d275029ae"
},
"outputs": [],
"source": [
"# Run with langgraph\n",
"experiment_results = evaluate(\n",
" predict_langgraph,\n",
" data=dataset_name,\n",
" evaluators=code_evalulator,\n",
" experiment_prefix=f\"test-with-langgraph-{expt_llm}-{flag}\",\n",
" max_concurrency=2,\n",
" metadata={\n",
" \"llm\": expt_llm,\n",
" \"feedback\": flag,\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d69da747-b4ea-455d-9314-60c3d9d30549",
"metadata": {
"id": "d69da747-b4ea-455d-9314-60c3d9d30549"
},
"source": [
"`Results:`\n",
"\n",
"* `LangGraph outperforms base case`: adding re-try loop improve performance\n",
"* `Reflection did not help`: reflection prior to re-try regression vs just passing errors directly back to the LLM\n",
"* `GPT-4 outperforms Claude3`: Claude3 had 3 and 1 run fail due to tool-use error for Opus and Haiku, respectively\n",
"\n",
"https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
},
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}