[d869a2]: / Notebooks / 12_random_split_test.ipynb

Download this file

370 lines (369 with data), 16.0 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wearsed.dataset.WearSEDDataset import WearSEDDataset\n",
    "from wearsed.models.vae.VAE import VAE, vae_loss\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "\n",
    "# Dataset\n",
    "full_dataset = WearSEDDataset(signals_to_read=['Pleth'])\n",
    "\n",
    "# Seed the split so the 80/20 train/test partition is the same on every run;\n",
    "# an unseeded random_split draws a new permutation each time the cell executes.\n",
    "SPLIT_SEED = 42\n",
    "split_generator = torch.Generator().manual_seed(SPLIT_SEED)\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(\n",
    "    full_dataset, [0.8, 0.2], generator=split_generator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1\n",
      "1 2\n",
      "2 6\n",
      "3 12\n",
      "4 14\n",
      "5 16\n",
      "6 21\n",
      "7 27\n",
      "8 28\n",
      "9 33\n"
     ]
    }
   ],
   "source": [
    "# Print position and .id for the first ten items of the unsplit dataset.\n",
    "for position in range(10):\n",
    "    record = full_dataset[position]\n",
    "    print(position, record.id)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 3016\n",
      "1 1844\n",
      "2 2309\n",
      "3 949\n",
      "4 5295\n",
      "5 4588\n",
      "6 3703\n",
      "7 837\n",
      "8 2227\n",
      "9 5369\n"
     ]
    }
   ],
   "source": [
    "# Print position and .id for the first ten items of the training subset.\n",
    "for position in range(10):\n",
    "    record = train_dataset[position]\n",
    "    print(position, record.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0         -0.025406\n",
       "1         -0.024033\n",
       "2         -0.022965\n",
       "3         -0.022049\n",
       "4         -0.021286\n",
       "             ...   \n",
       "9057019    0.025254\n",
       "9057020    0.020981\n",
       "9057021    0.016098\n",
       "9057022    0.010758\n",
       "9057023    0.004959\n",
       "Length: 9057024, dtype: float64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the 'Pleth' entry of the first training item's psg.\n",
    "first_train_item = train_dataset[0]\n",
    "first_train_item.psg['Pleth']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy stand-in data: 100 random samples with 256 values each,\n",
    "# served in shuffled mini-batches of 16.\n",
    "data = torch.rand((100, 256))\n",
    "train_loader = torch.utils.data.DataLoader(dataset=data, shuffle=True, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 256])\n",
      "torch.Size([16, 256])\n",
      "torch.Size([16, 256])\n",
      "torch.Size([16, 256])\n",
      "torch.Size([16, 256])\n",
      "torch.Size([16, 256])\n",
      "torch.Size([4, 256])\n"
     ]
    }
   ],
   "source": [
    "# One line per mini-batch yielded by the loader.\n",
    "for minibatch in train_loader:\n",
    "    batch_shape = minibatch.shape\n",
    "    print(batch_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two small float vectors to experiment with.\n",
    "a, b = torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2., 3.],\n",
       "        [4., 5., 6.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the two vectors along a new leading dimension (one row per vector).\n",
    "torch.stack((a, b), dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack two positional slices of a pandas Series into one 2-D tensor.\n",
    "full = pd.Series([1,2,3,4,5,6,7,8,9,10,11,12])\n",
    "a = full.iloc[3:6]\n",
    "b = full.iloc[7:10]\n",
    "\n",
    "# torch.stack only accepts tensors; passing the raw NumPy arrays from .values\n",
    "# raises TypeError, so copy each slice into a tensor first.\n",
    "stacked = torch.stack([torch.tensor(a.to_numpy()), torch.tensor(b.to_numpy())])\n",
    "stacked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wearsed.training.vae.train_vae import get_batch\n",
    "from wearsed.dataset.WearSEDDataset import WearSEDDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the dataset, requesting only the 'Pleth' signal, and keep the item\n",
    "# at index 1 as r1 for the cells below.\n",
    "full_dataset = WearSEDDataset(signals_to_read=['Pleth'])\n",
    "r1 = full_dataset[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0141,  0.0150,  0.0159,  ...,  0.0555,  0.0555,  0.0555],\n",
       "        [ 0.0773,  0.0780,  0.0787,  ...,  0.1037,  0.1049,  0.1061],\n",
       "        [-0.1217, -0.1199, -0.1180,  ..., -0.0356, -0.0307, -0.0263],\n",
       "        [ 0.0469,  0.0480,  0.0494,  ...,  0.0454,  0.0454,  0.0452]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Call the project's get_batch helper on r1 with 4 as the second argument;\n",
    "# the saved output is a tensor with 4 rows (presumably one per sampled window - confirm in train_vae).\n",
    "get_batch(r1, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# max_time is the latest start index at which a 256-sample slice still fits in the signal.\n",
    "signal = r1.psg['Pleth']\n",
    "n_samples = len(signal)\n",
    "max_time = n_samples - 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4486851"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw 4 random indices in [0, max_time) and show the first as a plain Python int.\n",
    "random_starts = torch.randint(0, max_time, (4,))\n",
    "random_starts[0].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print four randomly positioned 256-sample windows of the signal.\n",
    "batch = []\n",
    "# .tolist() turns the sampled indices into plain Python ints; iterating the\n",
    "# tensor directly yields 0-d tensors, which pandas rejects as slice bounds.\n",
    "for timepoint in torch.randint(0, max_time, (4,)).tolist():\n",
    "    print(signal[timepoint:timepoint+256])\n",
    "    #batch.append(torch.Tensor(signal[timepoint:timepoint+256].values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.)\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: summed squared error divided by the element count (2 rows x 3 columns).\n",
    "a = torch.Tensor([[1, 2, 3], [5, 6, 7]])\n",
    "b = torch.Tensor([[1, 2, 1], [5, 7, 8]])\n",
    "\n",
    "summed_sq_err = F.mse_loss(a, b, reduction='sum')\n",
    "print(summed_sq_err / (2 * 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "master",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}