{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "loading_with_webdatasets",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "8SxP40j_pX2e"
},
"source": [
"Method for loading training data that uses the webdatasets library and saves us a ton of disk and ram issues."
]
},
{
"cell_type": "code",
"metadata": {
"id": "65WtZnBCobLf"
},
"source": [
"!pip install webdataset"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "bEQHUYdVokOX"
},
"source": [
"import nibabel as nb\r\n",
"import numpy as np\r\n",
"from io import BytesIO\r\n",
"from nibabel import FileHolder, Nifti1Image\r\n",
"import os\r\n",
"import torch\r\n",
"from skimage import transform\r\n",
"import webdataset as wds"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "uAyoX4ylnyWK"
},
"source": [
"from google.colab import drive\r\n",
"drive.mount('/content/drive', force_remount=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9fy3LDQ2ouKR"
},
"source": [
"dataset = wds.Dataset(\"./drive/MyDrive/macai_datasets/brats_training.tar.gz\")"
],
"execution_count": null,
"outputs": []
},
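{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not part of the original pipeline, but a handy sanity check (a sketch assuming the legacy `wds.Dataset` iterator yields one dict per tar member): peek at the first few samples to confirm the key/bytes layout that the collate function below relies on."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import itertools\r\n",
"\r\n",
"# Assumption: wds.Dataset (webdataset 0.1.x) is directly iterable and yields\r\n",
"# one dict per tar member, e.g. {'__key__': <path>, 'nii.gz': <bytes>}.\r\n",
"for sample in itertools.islice(dataset, 3):\r\n",
"    items = list(sample.items())\r\n",
"    print(items[0], '->', len(items[1][1]), 'bytes')"
],
"execution_count": null,
"outputs": []
},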
{
"cell_type": "code",
"metadata": {
"id": "btQwqJ5yoaQI"
},
"source": [
"import nibabel as nb\r\n",
"import numpy as np\r\n",
"from io import BytesIO\r\n",
"from nibabel import FileHolder, Nifti1Image\r\n",
"import os\r\n",
"import torch\r\n",
"from skimage import transform\r\n",
"import webdataset as wds\r\n",
"\r\n",
"\r\n",
"train_dataset = wds.Dataset(\"./drive/MyDrive/macai_datasets/brats_training.tar.gz\")\r\n",
"eval_dataset = wds.Dataset(\"./drive/MyDrive/macai_datasets/brats_training.tar.gz\")\r\n",
"\r\n",
"\r\n",
"\r\n",
"def col_img(batch):\r\n",
" bytes_data_list = [list(batch[i].items())[1][1] for i in range(5)] \r\n",
" \r\n",
" bb = BytesIO(bytes_data_list[0])\r\n",
" fh = FileHolder(fileobj=bb)\r\n",
" f_flair = Nifti1Image.from_file_map({'header': fh, 'image': fh}).get_fdata()\r\n",
"\r\n",
" bb = BytesIO(bytes_data_list[1])\r\n",
" fh = FileHolder(fileobj=bb)\r\n",
" f_seg = Nifti1Image.from_file_map({'header': fh, 'image': fh}).get_fdata()\r\n",
"\r\n",
" bb = BytesIO(bytes_data_list[2])\r\n",
" fh = FileHolder(fileobj=bb)\r\n",
" f_t1 = Nifti1Image.from_file_map({'header': fh, 'image': fh}).get_fdata()\r\n",
"\r\n",
" bb = BytesIO(bytes_data_list[3])\r\n",
" fh = FileHolder(fileobj=bb)\r\n",
" f_t1ce = Nifti1Image.from_file_map({'header': fh, 'image': fh}).get_fdata() \r\n",
"\r\n",
" bb = BytesIO(bytes_data_list[4])\r\n",
" fh = FileHolder(fileobj=bb)\r\n",
" f_t2 = Nifti1Image.from_file_map({'header': fh, 'image': fh}).get_fdata() \r\n",
"\r\n",
" f_t1 = transform.resize(f_t1, [320, 400, 320])\r\n",
" f_t2 = transform.resize(f_t2, [320, 400, 320])\r\n",
" f_t1ce = transform.resize(f_t1ce, [320, 400, 320])\r\n",
" f_flair = transform.resize(f_flair, [320, 400, 320])\r\n",
" f_seg = transform.resize(f_seg, [320, 400, 320])\r\n",
" return [torch.tensor(np.stack([f_t1, f_t1ce, f_t2, f_flair])), torch.tensor(f_seg)]"
],
"execution_count": null,
"outputs": []
},
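{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional shape check, not from the original notebook: pull one subject's five files by hand and push them through `col_img` exactly as the DataLoader would, to confirm the stacked modalities and the segmentation come out at the expected sizes."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import itertools\r\n",
"\r\n",
"# Grab the first five tar members (one subject's worth of files) and collate\r\n",
"# them manually; this is exactly what DataLoader(batch_size=5) will do.\r\n",
"probe_batch = list(itertools.islice(train_dataset, 5))\r\n",
"images, seg = col_img(probe_batch)\r\n",
"print(images.shape)  # expected: torch.Size([4, 320, 400, 320])\r\n",
"print(seg.shape)     # expected: torch.Size([320, 400, 320])"
],
"execution_count": null,
"outputs": []
},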
{
"cell_type": "code",
"metadata": {
"id": "KYygUfG30MmQ"
},
"source": [
"#Very janky way of seperating into train and validation \r\n",
"eval_dataset.select(lambda x : int(list(x.items())[0][1].split('/')[-1][17]) == 3)\r\n",
"train_dataset.select(lambda x : int(list(x.items())[0][1].split('/')[-1][17]) < 3)\r\n",
"\r\n",
"train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=5,collate_fn=col_img)\r\n",
"eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=5,collate_fn=col_img)"
],
"execution_count": null,
"outputs": []
}
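,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the loaders wired up, consuming them is plain PyTorch. A minimal sketch of a consumption loop; `model`, `criterion`, and `optimizer` are hypothetical placeholders, nothing here is defined by the notebook above."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical consumption loop: model, criterion, and optimizer are\r\n",
"# placeholders and are not defined anywhere in this notebook.\r\n",
"for images, seg in train_dataloader:\r\n",
"    images = images.float()  # volumes decode as float64; most models want float32\r\n",
"    # preds = model(images.unsqueeze(0))\r\n",
"    # loss = criterion(preds, seg.long().unsqueeze(0))\r\n",
"    # loss.backward(); optimizer.step(); optimizer.zero_grad()\r\n",
"    print(images.shape, seg.shape)\r\n",
"    break  # just prove one batch makes it through"
],
"execution_count": null,
"outputs": []
}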
]
}