--- a
+++ b/save_middle_slices_as_images.ipynb
@@ -0,0 +1,258 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Save middle slices of scans as images\n",
+    "Most image models expect a three-channel input image. This notebook takes the middle three grayscale slices of each scan, as well as dilated triplets (three slices centered on the middle slice with a gap between neighbouring slices), and saves them as three-channel images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from PIL import Image\n",
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[01;34m..\u001b[00m\n",
+      "├── \u001b[01;34mdata\u001b[00m\n",
+      "│   ├── \u001b[01;34m__MACOSX\u001b[00m\n",
+      "│   │   └── \u001b[01;34mMRNet-v1.0\u001b[00m\n",
+      "│   │       ├── \u001b[01;34mtrain\u001b[00m\n",
+      "│   │       │   ├── \u001b[01;34maxial\u001b[00m\n",
+      "│   │       │   ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│   │       │   └── \u001b[01;34msagittal\u001b[00m\n",
+      "│   │       └── \u001b[01;34mvalid\u001b[00m\n",
+      "│   ├── \u001b[01;34mmid1\u001b[00m\n",
+      "│   │   ├── \u001b[01;34mtrain\u001b[00m\n",
+      "│   │   │   ├── \u001b[01;34maxial\u001b[00m\n",
+      "│   │   │   ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│   │   │   └── \u001b[01;34msagittal\u001b[00m\n",
+      "│   │   └── \u001b[01;34mvalid\u001b[00m\n",
+      "│   │       ├── \u001b[01;34maxial\u001b[00m\n",
+      "│   │       ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│   │       └── \u001b[01;34msagittal\u001b[00m\n",
+      "│   ├── \u001b[01;34mMRNet-small\u001b[00m\n",
+      "│   │   ├── \u001b[01;34mtrain\u001b[00m\n",
+      "│   │   │   ├── \u001b[01;34maxial\u001b[00m\n",
+      "│   │   │   ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│   │   │   └── \u001b[01;34msagittal\u001b[00m\n",
+      "│   │   └── \u001b[01;34mvalid\u001b[00m\n",
+      "│   │       ├── \u001b[01;34maxial\u001b[00m\n",
+      "│   │       ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│   │       └── \u001b[01;34msagittal\u001b[00m\n",
+      "│   └── \u001b[01;34mMRNet-v1.0\u001b[00m\n",
+      "│       ├── \u001b[01;34mtrain\u001b[00m\n",
+      "│       │   ├── \u001b[01;34maxial\u001b[00m\n",
+      "│       │   ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│       │   └── \u001b[01;34msagittal\u001b[00m\n",
+      "│       └── \u001b[01;34mvalid\u001b[00m\n",
+      "│           ├── \u001b[01;34maxial\u001b[00m\n",
+      "│           ├── \u001b[01;34mcoronal\u001b[00m\n",
+      "│           └── \u001b[01;34msagittal\u001b[00m\n",
+      "└── \u001b[01;34mmrnet-fastai\u001b[00m\n",
+      "    └── \u001b[01;34mexp\u001b[00m\n",
+      "\n",
+      "37 directories\n"
+     ]
+    }
+   ],
+   "source": [
+    "!tree -d .."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Copy the data to a new directory named after the configuration, for example `../data/mid3d0` for the middle 3 slices with dilation 0.\n",
+    "Repeat this for each dilation d in (0, 1, 2), and for larger values of d if needed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!test -d ../data/mid3d0 || cp -r ../data/MRNet-v1.0 ../data/mid3d0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Walk through the `train` and `valid` directories, load each scan file (`.npy`), select the middle 3 slices (skipping `d` slices between neighbours for dilation `d`), and save them as one 3-channel image. Each copied `.npy` file is deleted once its image has been written."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "-----------------------\n",
+      "Creating png files centered on middle slices with dilations of size 2\n",
+      "Processing 1130 .npy files in ../data/mid3d2/train/axial\n",
+      "===============================\n",
+      "Processing 1130 .npy files in ../data/mid3d2/train/coronal\n",
+      "===============================\n",
+      "Processing 1130 .npy files in ../data/mid3d2/train/sagittal\n",
+      "===============================\n",
+      "Processing 120 .npy files in ../data/mid3d2/valid/axial\n",
+      "===============================\n",
+      "Processing 120 .npy files in ../data/mid3d2/valid/coronal\n",
+      "===============================\n",
+      "Processing 120 .npy files in ../data/mid3d2/valid/sagittal\n",
+      "===============================\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Build one 3-channel PNG per scan from three slices centred on the middle slice.\n",
+    "dilations = [0,1,2]\n",
+    "for d in dilations:\n",
+    "    print('-----------------------')\n",
+    "    print('Creating png files centered on middle slices with dilations of size {}'.format(d))\n",
+    "    root = '../data/mid3d{}'.format(d)\n",
+    "    if not os.path.isdir(root):\n",
+    "        # a dilation folder that was never copied is skipped instead of crashing half-way\n",
+    "        print('Skipping {}: not found, copy ../data/MRNet-v1.0 there first'.format(root))\n",
+    "        continue\n",
+    "    for ds in ('train','valid'):\n",
+    "        for p in ('axial','coronal','sagittal'):\n",
+    "            dirpath = '{}/{}/{}'.format(root, ds, p)\n",
+    "            npy_files = sorted(f for f in os.listdir(dirpath) if f.endswith('.npy'))\n",
+    "            print('Processing {} .npy files in {}'.format(len(npy_files), dirpath))\n",
+    "\n",
+    "            for npyf in npy_files:\n",
+    "                npyfilepath = os.path.join(dirpath, npyf)\n",
+    "                scanarray = np.load(npyfilepath)\n",
+    "                n_slices = scanarray.shape[0]\n",
+    "                midslice = n_slices//2\n",
+    "                # three indices centred on midslice, skipping d slices between neighbours\n",
+    "                slice_indices = np.arange(midslice-(1+d), midslice+(1+d)+1, 1+d)\n",
+    "                # clamp to the valid range: a short scan reuses an edge slice instead of\n",
+    "                # wrapping around via a negative index or raising IndexError mid-run\n",
+    "                slice_indices = np.clip(slice_indices, 0, n_slices-1)\n",
+    "                # move the slice axis last so PIL reads the three slices as channels\n",
+    "                slices3 = np.moveaxis(scanarray[slice_indices,:,:], 0, -1)\n",
+    "                im = Image.fromarray(slices3)\n",
+    "                # write <name>.png next to the scan, then remove the .npy it replaces\n",
+    "                im.save(os.path.join(dirpath, os.path.splitext(npyf)[0] + '.png'))\n",
+    "                os.remove(npyfilepath)\n",
+    "\n",
+    "            print('===============================')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Take just the middle slice from each scan"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Found 1130 .npy files in ../data/mid1/train/axial\n",
+      "===============================\n",
+      "===============================\n",
+      "Found 1130 .npy files in ../data/mid1/train/coronal\n",
+      "===============================\n",
+      "===============================\n",
+      "Found 1130 .npy files in ../data/mid1/train/sagittal\n",
+      "===============================\n",
+      "===============================\n",
+      "Found 120 .npy files in ../data/mid1/valid/axial\n",
+      "===============================\n",
+      "===============================\n",
+      "Found 120 .npy files in ../data/mid1/valid/coronal\n",
+      "===============================\n",
+      "===============================\n",
+      "Found 120 .npy files in ../data/mid1/valid/sagittal\n",
+      "===============================\n",
+      "===============================\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Save the single middle slice of every scan under ../data/mid1 as an RGB png.\n",
+    "# The .npy files are converted in place: each one is deleted once its png is saved.\n",
+    "root = '../data/mid1'\n",
+    "for ds in ('train','valid'):\n",
+    "    for p in ('axial','coronal','sagittal'):\n",
+    "        dirpath = '{}/{}/{}'.format(root, ds, p)\n",
+    "        # sorted so the conversion order does not depend on os.listdir ordering\n",
+    "        npy_files = sorted(f for f in os.listdir(dirpath) if f.endswith('.npy'))\n",
+    "        print('Processing {} .npy files in {}'.format(len(npy_files), dirpath))\n",
+    "\n",
+    "        for npyf in npy_files:\n",
+    "            npyfilepath = os.path.join(dirpath, npyf)\n",
+    "            scanarray = np.load(npyfilepath)\n",
+    "            # the first axis indexes the slices; pick the one in the middle\n",
+    "            midslice = scanarray.shape[0]//2\n",
+    "            # let PIL infer the mode from the dtype (L for uint8): forcing 'L' only changes\n",
+    "            # how the buffer is read, it does not convert data of another dtype\n",
+    "            im = Image.fromarray(scanarray[midslice,:,:]).convert('RGB')\n",
+    "            # replace .npy with .png\n",
+    "            pngfilepath = os.path.join(dirpath, os.path.splitext(npyf)[0] + '.png')\n",
+    "            im.save(pngfilepath)\n",
+    "            os.remove(npyfilepath)\n",
+    "\n",
+    "        print('===============================')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}