[7e250a]: / src / hint / notebooks / data_preparation.ipynb

Download this file

194 lines (193 with data), 7.7 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "38df5ac2-ec56-4cd6-8d87-33ea12b9c72d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/varadi_kristof/llms-for-trials/src/hint\n"
     ]
    }
   ],
   "source": [
    "%cd ../"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7deef9de-5bc5-4f84-b4c2-1139bc766b38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e98949ff-05af-4fd3-b708-d5797138be77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the per-phase, per-split CSV files.\n",
    "data_dir = \"./data/trial\"\n",
    "\n",
    "\n",
    "def _read_split(split):\n",
    "    \"\"\"Read the phase I, II and III CSV files of one split, in that order.\"\"\"\n",
    "    return [\n",
    "        pd.read_csv(f\"{data_dir}/phase_{phase}_{split}.csv\")\n",
    "        for phase in (\"I\", \"II\", \"III\")\n",
    "    ]\n",
    "\n",
    "\n",
    "p1_train, p2_train, p3_train = _read_split(\"train\")\n",
    "p1_test, p2_test, p3_test = _read_split(\"test\")\n",
    "p1_valid, p2_valid, p3_valid = _read_split(\"valid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2485e49f-468b-47f9-a5a7-3b70fd74bb51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the three phases of every split into one frame with a fresh 0..n-1 index.\n",
    "train_merged, test_merged, valid_merged = (\n",
    "    pd.concat(phase_frames, ignore_index=True)\n",
    "    for phase_frames in (\n",
    "        [p1_train, p2_train, p3_train],\n",
    "        [p1_test, p2_test, p3_test],\n",
    "        [p1_valid, p2_valid, p3_valid],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "281e659e-cfdd-4ce3-a593-e5f50c8f36eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "label\n",
       "1        4527\n",
       "0        3616\n",
       "dtype: int64"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count how many rows of the merged training split carry each label value.\n",
    "label_frame = train_merged[[\"label\"]]\n",
    "label_frame.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1ad92618-557b-4baa-a9d2-d92e024bba6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One output CSV per split, written next to the per-phase inputs.\n",
    "train_file, valid_file, test_file = (\n",
    "    f\"{data_dir}/{split}.csv\" for split in (\"train\", \"valid\", \"test\")\n",
    ")\n",
    "\n",
    "# Persist the merged frames without the synthetic row index.\n",
    "for merged_frame, target_path in (\n",
    "    (train_merged, train_file),\n",
    "    (valid_merged, valid_file),\n",
    "    (test_merged, test_file),\n",
    "):\n",
    "    merged_frame.to_csv(target_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "be20385d-febb-46b8-9701-3fe15797fb9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder the embedding step reads its input from: it is addressed there as\n",
    "# ../../../data relative to toxicity/multitask-toxicity/SE_featurization.\n",
    "smiles_dir = \"./data\"\n",
    "\n",
    "\n",
    "def explode_list(row):\n",
    "    \"\"\"Evaluate one `smiless` cell with ast.literal_eval and return the result.\n",
    "\n",
    "    Each cell is assumed to hold the text of a Python list literal -- TODO confirm\n",
    "    against the CSV files.\n",
    "    \"\"\"\n",
    "    return ast.literal_eval(row)\n",
    "\n",
    "\n",
    "def merge_smiles(phase_frames):\n",
    "    \"\"\"Concatenate the `smiless` column of the given frames under a fresh 0..n-1 index.\"\"\"\n",
    "    return pd.concat([frame[\"smiless\"] for frame in phase_frames], ignore_index=True)\n",
    "\n",
    "\n",
    "def explode_smiles(merged):\n",
    "    \"\"\"Parse every cell with explode_list and return one list element per row.\"\"\"\n",
    "    return merged.apply(explode_list).explode().reset_index(drop=True)\n",
    "\n",
    "\n",
    "merged_train = merge_smiles([p1_train, p2_train, p3_train])\n",
    "merged_test = merge_smiles([p1_test, p2_test, p3_test])\n",
    "merged_valid = merge_smiles([p1_valid, p2_valid, p3_valid])\n",
    "\n",
    "exploded_train = explode_smiles(merged_train)\n",
    "exploded_test = explode_smiles(merged_test)\n",
    "exploded_valid = explode_smiles(merged_valid)\n",
    "\n",
    "# Write to dedicated smiles_<split>.csv files. Reusing train_file / test_file /\n",
    "# valid_file here would overwrite the merged trial tables with bare SMILES rows,\n",
    "# and the embedding step reads smiles_<split>.csv, not <split>.csv.\n",
    "exploded_train.to_csv(f\"{smiles_dir}/smiles_train.csv\", index=False, header=False)\n",
    "exploded_test.to_csv(f\"{smiles_dir}/smiles_test.csv\", index=False, header=False)\n",
    "exploded_valid.to_csv(f\"{smiles_dir}/smiles_valid.csv\", index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b8ae6755-73dd-43ce-91fd-83b7fe20f738",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/varadi_kristof/llms-for-trials/src/hint/toxicity/multitask-toxicity/SE_featurization\n",
      "ArgumentParser(prog='save_embeddings.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)\n",
      "Namespace(config_load='models/config.nb', device='cuda', input_file='../../../data/smiles_train.csv', model='translation', model_load='models/model.pt', n_batch=65, output_file='../../../data/smiles_embed_train.pt', vocab_load='models/vocab.nb')\n",
      "Model loaded.\n",
      "Calculating Embeddings..\n",
      "100%|█████████████████████████████████████████| 267/267 [00:15<00:00, 17.02it/s]\n",
      "ArgumentParser(prog='save_embeddings.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)\n",
      "Namespace(config_load='models/config.nb', device='cuda', input_file='../../../data/smiles_test.csv', model='translation', model_load='models/model.pt', n_batch=65, output_file='../../../data/smiles_embed_test.pt', vocab_load='models/vocab.nb')\n",
      "Model loaded.\n",
      "Calculating Embeddings..\n",
      "100%|█████████████████████████████████████████| 104/104 [00:06<00:00, 15.61it/s]\n",
      "ArgumentParser(prog='save_embeddings.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)\n",
      "Namespace(config_load='models/config.nb', device='cuda', input_file='../../../data/smiles_valid.csv', model='translation', model_load='models/model.pt', n_batch=65, output_file='../../../data/smiles_embed_valid.pt', vocab_load='models/vocab.nb')\n",
      "Model loaded.\n",
      "Calculating Embeddings..\n",
      "100%|███████████████████████████████████████████| 31/31 [00:30<00:00,  1.01it/s]\n"
     ]
    }
   ],
   "source": [
    "%cd ./toxicity/multitask-toxicity/SE_featurization\n",
    "# Run the embedding script once per split with identical model settings;\n",
    "# IPython substitutes {split} into the shell command before running it.\n",
    "for split in (\"train\", \"test\", \"valid\"):\n",
    "    !python -m scripts.save_embeddings --input_file ../../../data/smiles_{split}.csv  --output_file ../../../data/smiles_embed_{split}.pt --model translation --device cuda --model_load models/model.pt --vocab_load models/vocab.nb --config_load models/config.nb --n_batch 65"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}