|
Notebooks/old/loading_with_webdatasets.ipynb
|
|
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "loading_with_webdatasets",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8SxP40j_pX2e"
      },
      "source": [
        "Method for loading training data with the webdataset library: streaming samples straight out of a tar archive saves us a ton of disk and RAM trouble."
      ]
    },
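    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The notebook assumes the scans have already been packed into a single tar archive. A minimal sketch of how such an archive could be built with webdataset's `TarWriter`, writing each NIfTI file as its own sample (the `BraTS_root` layout and key naming are assumptions, not the actual packing script):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Hypothetical packing step (the real archive was built elsewhere): each\r\n",
        "# .nii.gz file becomes its own webdataset sample, which is what the collate\r\n",
        "# function below relies on. Compression handling may vary by wds version.\r\n",
        "import os\r\n",
        "import webdataset as wds\r\n",
        "\r\n",
        "with wds.TarWriter(\"brats_training.tar.gz\") as sink:\r\n",
        "    for root, _, files in os.walk(\"BraTS_root\"):  # assumed directory layout\r\n",
        "        for fname in sorted(files):\r\n",
        "            if not fname.endswith(\".nii.gz\"):\r\n",
        "                continue\r\n",
        "            with open(os.path.join(root, fname), \"rb\") as f:\r\n",
        "                sink.write({\r\n",
        "                    \"__key__\": os.path.join(root, fname)[:-len(\".nii.gz\")],\r\n",
        "                    \"nii.gz\": f.read(),\r\n",
        "                })"
      ],
      "execution_count": null,
      "outputs": []
    },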
|
|
    {
      "cell_type": "code",
      "metadata": {
        "id": "65WtZnBCobLf"
      },
      "source": [
        "!pip install webdataset"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bEQHUYdVokOX"
      },
      "source": [
        "import nibabel as nb\r\n",
        "import numpy as np\r\n",
        "from io import BytesIO\r\n",
        "from nibabel import FileHolder, Nifti1Image\r\n",
        "import os\r\n",
        "import torch\r\n",
        "from skimage import transform\r\n",
        "import webdataset as wds"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uAyoX4ylnyWK"
      },
      "source": [
        "from google.colab import drive\r\n",
        "drive.mount('/content/drive', force_remount=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9fy3LDQ2ouKR"
      },
      "source": [
        "dataset = wds.Dataset(\"./drive/MyDrive/macai_datasets/brats_training.tar.gz\")"
      ],
      "execution_count": null,
      "outputs": []
    },
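    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before wiring up the collate function it can help to peek at one raw sample. Iterating the dataset should yield plain dicts mapping the sample key and file extension to raw bytes; a quick sanity check, assuming the tar layout sketched above:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Pull a single raw sample to see what webdataset hands us: a dict whose\r\n",
        "# values are the raw bytes of one file (plus bookkeeping entries).\r\n",
        "sample = next(iter(dataset))\r\n",
        "for k, v in sample.items():\r\n",
        "    print(k, type(v), len(v) if isinstance(v, bytes) else v)"
      ],
      "execution_count": null,
      "outputs": []
    },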
|
|
    {
      "cell_type": "code",
      "metadata": {
        "id": "btQwqJ5yoaQI"
      },
      "source": [
        "import nibabel as nb\r\n",
        "import numpy as np\r\n",
        "from io import BytesIO\r\n",
        "from nibabel import FileHolder, Nifti1Image\r\n",
        "import os\r\n",
        "import torch\r\n",
        "from skimage import transform\r\n",
        "import webdataset as wds\r\n",
        "\r\n",
        "\r\n",
        "train_dataset = wds.Dataset(\"./drive/MyDrive/macai_datasets/brats_training.tar.gz\")\r\n",
        "eval_dataset = wds.Dataset(\"./drive/MyDrive/macai_datasets/brats_training.tar.gz\")\r\n",
        "\r\n",
        "\r\n",
        "def nifti_from_bytes(raw):\r\n",
        "    # Decode an in-memory NIfTI file: wrap the raw bytes so nibabel can\r\n",
        "    # read them without ever touching the filesystem.\r\n",
        "    fh = FileHolder(fileobj=BytesIO(raw))\r\n",
        "    return Nifti1Image.from_file_map({'header': fh, 'image': fh}).get_fdata()\r\n",
        "\r\n",
        "\r\n",
        "def col_img(batch):\r\n",
        "    # Each webdataset sample is a dict whose first item is the sample key and\r\n",
        "    # whose second item is the raw bytes of one NIfTI file. With batch_size=5,\r\n",
        "    # one batch is assumed to hold one subject's five files in tar\r\n",
        "    # (alphabetical) order: flair, seg, t1, t1ce, t2.\r\n",
        "    bytes_data_list = [list(batch[i].items())[1][1] for i in range(5)]\r\n",
        "\r\n",
        "    f_flair = nifti_from_bytes(bytes_data_list[0])\r\n",
        "    f_seg = nifti_from_bytes(bytes_data_list[1])\r\n",
        "    f_t1 = nifti_from_bytes(bytes_data_list[2])\r\n",
        "    f_t1ce = nifti_from_bytes(bytes_data_list[3])\r\n",
        "    f_t2 = nifti_from_bytes(bytes_data_list[4])\r\n",
        "\r\n",
        "    # Resample every volume to a common shape so they can be stacked.\r\n",
        "    f_t1 = transform.resize(f_t1, [320, 400, 320])\r\n",
        "    f_t2 = transform.resize(f_t2, [320, 400, 320])\r\n",
        "    f_t1ce = transform.resize(f_t1ce, [320, 400, 320])\r\n",
        "    f_flair = transform.resize(f_flair, [320, 400, 320])\r\n",
        "    f_seg = transform.resize(f_seg, [320, 400, 320])\r\n",
        "\r\n",
        "    # Four modalities become the input channels; the segmentation is the target.\r\n",
        "    return [torch.tensor(np.stack([f_t1, f_t1ce, f_t2, f_flair])), torch.tensor(f_seg)]"
      ],
      "execution_count": null,
      "outputs": []
    },
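    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick way to check that `col_img` decodes and stacks the volumes as intended is to run it directly on one hand-built batch of five raw samples (a sketch; the five-files-per-subject grouping is the same assumption the collate function makes):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import itertools\r\n",
        "\r\n",
        "# Take the first five raw samples (one subject's five files, if the tar is\r\n",
        "# ordered as assumed) and run the collate function on them by hand.\r\n",
        "batch = list(itertools.islice(iter(train_dataset), 5))\r\n",
        "imgs, seg = col_img(batch)\r\n",
        "print(imgs.shape)  # expected: torch.Size([4, 320, 400, 320])\r\n",
        "print(seg.shape)   # expected: torch.Size([320, 400, 320])"
      ],
      "execution_count": null,
      "outputs": []
    },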
|
|
    {
      "cell_type": "code",
      "metadata": {
        "id": "KYygUfG30MmQ"
      },
      "source": [
        "# Very janky way of separating into train and validation: split samples on\r\n",
        "# the digit at position 17 of the filename.\r\n",
        "eval_dataset.select(lambda x: int(list(x.items())[0][1].split('/')[-1][17]) == 3)\r\n",
        "train_dataset.select(lambda x: int(list(x.items())[0][1].split('/')[-1][17]) < 3)\r\n",
        "\r\n",
        "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=5, collate_fn=col_img)\r\n",
        "eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=5, collate_fn=col_img)"
      ],
      "execution_count": null,
      "outputs": []
    },
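    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal consumption sketch: pull one batch from each loader and report shapes. (With `batch_size=5` and the collate function above, one \"batch\" is one subject: a `[4, 320, 400, 320]` image stack plus its segmentation.)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Placeholder loop: just pull one batch from each loader to confirm the\r\n",
        "# pipeline runs end to end.\r\n",
        "for imgs, seg in train_dataloader:\r\n",
        "    print('train batch', imgs.shape, seg.shape)\r\n",
        "    break\r\n",
        "\r\n",
        "for imgs, seg in eval_dataloader:\r\n",
        "    print('eval batch', imgs.shape, seg.shape)\r\n",
        "    break"
      ],
      "execution_count": null,
      "outputs": []
    }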
|
|
  ]
}