116 lines (115 with data), 3.2 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from wearsed.dataset.WearSEDDataset_vanilla import WearSEDDataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_random_sequence(signals, labels, max_time, seq_length):\n",
" start = torch.randint(0, max_time, (1,))\n",
" end = start + seq_length\n",
" return signals[:, start:end], labels[start:end]\n",
"\n",
"def get_batch(signals, labels, batch_size, seq_length):\n",
" max_time = signals.shape[1] - seq_length\n",
"\n",
" tries = 0\n",
" batch_signals = []\n",
" batch_labels = []\n",
" for i in range(batch_size):\n",
" signals_seq, labels_seq = get_random_sequence(signals, labels, max_time, seq_length)\n",
" if i < 0.8*batch_size or tries < 20: # If there are not enough sequences with events and we haven't tried long enough\n",
" if labels_seq.sum() == 0: # ..check if this is a sequence without events and if so, roll again\n",
" signals_seq, labels_seq = get_random_sequence(signals, labels, max_time, seq_length) # TODO could still be negative sample\n",
" batch_signals.append(signals_seq)\n",
" batch_labels.append(labels_seq)\n",
" tries += 1\n",
" continue\n",
" batch_signals.append(signals_seq)\n",
" batch_labels.append(labels_seq)\n",
" return torch.stack(batch_signals), torch.stack(batch_labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = WearSEDDataset(mesaid_path='../wearsed/dataset/data_ids/', signals_to_read=['SpO2', 'Pleth'], preprocess=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"42 / 64\n",
"57 / 64\n",
"63 / 64\n",
"37 / 64\n",
"37 / 64\n",
"41 / 64\n",
"62 / 64\n",
"54 / 64\n",
"49 / 64\n",
"48 / 64\n",
"25 / 64\n",
"54 / 64\n",
"52 / 64\n",
"56 / 64\n",
"34 / 64\n",
"56 / 64\n",
"57 / 64\n",
"44 / 64\n",
"37 / 64\n",
"59 / 64\n"
]
}
],
"source": [
"for i in range(20):\n",
" signals, labels = dataset[i]\n",
" _, y = get_batch(signals, labels, batch_size=64, seq_length=600)\n",
" print((y.sum(dim=1)>0).sum().item(), '/', y.shape[0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "master",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}