[6c1f44]: / synthetic_data / synthetic_data_generation_GPT.ipynb

Download this file

153 lines (152 with data), 4.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import openai\n",
    "import os\n",
    "from pathlib import Path\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "from collections import Counter"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPT API Calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Read the API key from the OPENAI_API_KEY environment variable so it is never saved in the notebook.\n",
    "## If the variable is unset, AA falls back to \"\" (the original placeholder); paste a key there only for local, uncommitted use.\n",
    "AA = os.environ.get(\"OPENAI_API_KEY\", \"\")\n",
    "openai.api_key = AA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_list_items(text):\n",
    "    \"\"\"\n",
    "    Pull the items out of a numbered list ('1. foo 2. bar') and return them one per line.\n",
    "\n",
    "    A list marker is a run of digits followed by a period that is not part of a\n",
    "    decimal number, so '3.5 mg' inside an item stays intact instead of starting a\n",
    "    new item. Each item ends at the next marker, the next newline, or the end of text.\n",
    "\n",
    "    Args:\n",
    "        text: Raw text expected to contain a numbered list. Non-string values\n",
    "            (such as the NaN pandas yields for an empty CSV cell) give ''.\n",
    "\n",
    "    Returns:\n",
    "        The item texts without their numbers, joined by newlines ('' if none found).\n",
    "    \"\"\"\n",
    "    # Missing values read back from CSV are floats (NaN); re.findall would raise TypeError on them\n",
    "    if not isinstance(text, str):\n",
    "        return ''\n",
    "\n",
    "    # Digits + '.' not glued to other digits or dots, so decimals like 3.5 are not treated as markers\n",
    "    marker = r'(?<![\\d.])\\d+\\.(?!\\d)'\n",
    "\n",
    "    # Find all list items using regex, including those on the same line\n",
    "    list_items = re.findall(marker + r'\\s*(.+?)(?=\\s*' + marker + r'|\\n|$)', text, flags=re.DOTALL)\n",
    "\n",
    "    # Remove any leftover leading number from each list item (e.g. '1. 2. foo' -> 'foo')\n",
    "    items_without_numbers = [re.sub(r'^\\s*\\d+\\.\\s*', '', item) for item in list_items]\n",
    "\n",
    "    # Return the extracted list items as a single string, joined by newlines\n",
    "    return '\\n'.join(items_without_numbers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gpt3_response(messages, temperature=0, max_tokens=1000, top_p=1, frequency_penalty=0, presence_penalty=0, best_of=1, model=\"gpt-3.5-turbo\"):\n",
    "    \"\"\"\n",
    "    Send a chat conversation to the OpenAI ChatCompletion endpoint and return the reply text.\n",
    "\n",
    "    Args:\n",
    "        messages: Chat messages, passed to the API unchanged.\n",
    "        temperature, max_tokens, top_p, frequency_penalty, presence_penalty:\n",
    "            Sampling settings forwarded to the API as-is.\n",
    "        best_of: Accepted for backward compatibility only; it is not sent to the API.\n",
    "        model: Name of the chat model to query (defaults to 'gpt-3.5-turbo').\n",
    "\n",
    "    Returns:\n",
    "        The 'content' field of the first choice's message in the API response.\n",
    "    \"\"\"\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        temperature=temperature,\n",
    "        max_tokens=max_tokens,\n",
    "        top_p=top_p,\n",
    "        frequency_penalty=frequency_penalty,\n",
    "        presence_penalty=presence_penalty\n",
    "    )\n",
    "    # Only the first returned choice is used\n",
    "    return response['choices'][0]['message']['content']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_synthetic_data(input_path: str, output_path: str):\n",
    "    \"\"\"\n",
    "    Generate one GPT response per prompt in a JSON file and save the results as CSV.\n",
    "\n",
    "    Each prompt entry must provide 'prompt_id' (shaped like '<x>_<label>-<adverse>...')\n",
    "    and 'prompt' (the chat messages handed to gpt3_response). If output_path already\n",
    "    exists, the new rows are appended to its contents; the 'formatted_response'\n",
    "    column is then recomputed for every row before the file is rewritten.\n",
    "\n",
    "    Args:\n",
    "        input_path: Path to the JSON file of prompts; its file name up to the\n",
    "            first '.' is stored in the 'writer' column.\n",
    "        output_path: Path of the CSV file to create or extend.\n",
    "    \"\"\"\n",
    "    # Load input prompts from JSON file\n",
    "    with open(input_path, encoding='utf8') as j:\n",
    "        prompts = json.load(j)\n",
    "\n",
    "    # Path copes with the platform's own separator (and with Path inputs), unlike split('/')\n",
    "    writer = Path(input_path).name.split('.')[0]\n",
    "\n",
    "    # Generate synthetic data for each prompt using GPT-3\n",
    "    data_list = []\n",
    "    for prompt in prompts:\n",
    "        # Parse the id once: the part after the first '_' holds '<label>-<adverse>'\n",
    "        id_parts = prompt['prompt_id'].split('_')[1].split('-')\n",
    "        data_list.append({\n",
    "            'prompt_id': prompt['prompt_id'],\n",
    "            'label': id_parts[0],\n",
    "            'adverse': id_parts[1],\n",
    "            'writer': writer,\n",
    "            # Increase temperature and max tokens for longer and more varied output\n",
    "            # NOTE(review): max_tokens=4000 plus the prompt must fit the model's context window - confirm for the model in use\n",
    "            'synthetic_data': gpt3_response(messages=prompt['prompt'], temperature=0.3, frequency_penalty=1.9, presence_penalty=0.9, max_tokens=4000),\n",
    "        })\n",
    "    # Explicit columns keep the frame well-formed even when the prompt list is empty\n",
    "    data = pd.DataFrame(data_list, columns=['prompt_id', 'label', 'adverse', 'writer', 'synthetic_data'])\n",
    "\n",
    "    # If the output file already exists, append the new data to it\n",
    "    if os.path.isfile(output_path):\n",
    "        df = pd.read_csv(output_path)\n",
    "        df = pd.concat([df, data], ignore_index=True)\n",
    "    else:\n",
    "        df = data\n",
    "\n",
    "    # Extract list items from the GPT-3 response and store in a new column.\n",
    "    # Empty responses come back from read_csv as NaN, so blank them first to keep the regex step from failing.\n",
    "    df['formatted_response'] = df['synthetic_data'].fillna('').apply(extract_list_items)\n",
    "\n",
    "    # Write the DataFrame to the output file\n",
    "    df.to_csv(output_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_path = ''   # JSON file of prompts to read\n",
    "output_path = ''  # CSV file to create or extend\n",
    "create_synthetic_data(input_path=input_path, output_path=output_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bwh_models",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1d28981372adf6627f31be41cb660cbd4a9050764ba0e8d7de31ed03d8776b3e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}