{
"cells": [
{
"cell_type": "markdown",
"source": [
"<a href=\"https://colab.research.google.com/github/neuroneural/brainchop/blob/master/py2tfjs/Convert_Trained_Model_To_TFJS.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
],
"metadata": {
"id": "A1c2bimsUbI2"
}
},
{
"cell_type": "markdown",
"source": [
"#Pytorch Model Conversion to tfjs"
],
"metadata": {
"id": "ef8frdPU3AWh"
}
},
{
"cell_type": "markdown",
"source": [
"In this example, you will find a simple steps on how to convert the trained **MeshNet** model to tfjs model that can be used with [**Brainchop**](https://neuroneural.github.io/brainchop/). Given a Gray Matter White Matter (GWM) segmentation model that segmenting the brain into three different regions,uccessful conversion to tfjs will result in two main files, the model.json file, and the weights bin file:\n",
"\n",
" -The model.json file consists of model topology and weights manifest.\n",
" -The binary weights file (i.e. *.bin) consists of the concatenated weight values.\n",
"\n",
"\n",
"\n",
"This conversoin pipeline example is part of the [**Brainchop**](https://neuroneural.github.io/brainchop/) project, where the basic MeshNet model is trained using **PyTorch**, and the resulting model converted to the **Tensorflow.js** (tfjs) model to be used with Brainchop.\n",
"\n",
"---\n",
"\n",
"**Authors:** [Mohamed Masoud](https://github.com/Mmasoud1), and [Sergey Plis](https://github.com/sergeyplis)\n",
"\n"
],
"metadata": {
"id": "tKyWAoiv3mN3"
}
},
{
"cell_type": "markdown",
"source": [
"### Fetch required python files from brainchop repository [**folder**](https://github.com/neuroneural/brainchop/tree/master/py2tfjs/conversion_example)"
],
"metadata": {
"id": "jTqL2GwjXszv"
}
},
{
"cell_type": "markdown",
"source": [
"We'll be calling three python scripts as libraries that need download from brainchop.\n"
],
"metadata": {
"id": "JHsePqAZ24A-"
}
},
{
"cell_type": "code",
"source": [
"!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/blendbatchnorm.py\"}\n",
"!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet.py\"}\n",
"!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet2tfjs.py\"}\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vekwxqlZUf0P",
"outputId": "39660321-eedb-4cdc-a5fa-c999e2482a3e"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2024-03-28 09:08:58-- https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/blendbatchnorm.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2254 (2.2K) [text/plain]\n",
"Saving to: ‘blendbatchnorm.py’\n",
"\n",
"\rblendbatchnorm.py 0%[ ] 0 --.-KB/s \rblendbatchnorm.py 100%[===================>] 2.20K --.-KB/s in 0s \n",
"\n",
"2024-03-28 09:08:58 (39.2 MB/s) - ‘blendbatchnorm.py’ saved [2254/2254]\n",
"\n",
"--2024-03-28 09:08:58-- https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 4711 (4.6K) [text/plain]\n",
"Saving to: ‘meshnet.py’\n",
"\n",
"meshnet.py 100%[===================>] 4.60K --.-KB/s in 0s \n",
"\n",
"2024-03-28 09:08:58 (73.0 MB/s) - ‘meshnet.py’ saved [4711/4711]\n",
"\n",
"--2024-03-28 09:08:58-- https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet2tfjs.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 8711 (8.5K) [text/plain]\n",
"Saving to: ‘meshnet2tfjs.py’\n",
"\n",
"meshnet2tfjs.py 100%[===================>] 8.51K --.-KB/s in 0s \n",
"\n",
"2024-03-28 09:08:58 (103 MB/s) - ‘meshnet2tfjs.py’ saved [8711/8711]\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"We'll be calling the saved trained Pytorch model to be converted to TFJS."
],
"metadata": {
"id": "tWea0YpISfBu"
}
},
{
"cell_type": "code",
"source": [
"!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/modelAE.json\"}\n",
"!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/model.pth\"}\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T_PKTsVHQt82",
"outputId": "c842f2f5-ba30-4d97-c356-826f41e69136"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2024-03-28 09:09:02-- https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/modelAE.json\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3044 (3.0K) [text/plain]\n",
"Saving to: ‘modelAE.json’\n",
"\n",
"\rmodelAE.json 0%[ ] 0 --.-KB/s \rmodelAE.json 100%[===================>] 2.97K --.-KB/s in 0s \n",
"\n",
"2024-03-28 09:09:02 (45.2 MB/s) - ‘modelAE.json’ saved [3044/3044]\n",
"\n",
"--2024-03-28 09:09:02-- https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/model.pth\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 484293 (473K) [application/octet-stream]\n",
"Saving to: ‘model.pth’\n",
"\n",
"model.pth 100%[===================>] 472.94K --.-KB/s in 0.02s \n",
"\n",
"2024-03-28 09:09:02 (20.3 MB/s) - ‘model.pth’ saved [484293/484293]\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from blendbatchnorm import fuse_bn_recursively\n",
"from meshnet2tfjs import meshnet2tfjs\n",
"from meshnet import (\n",
" MeshNet,\n",
" enMesh_checkpoint,\n",
")\n",
"\n",
"device_name = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"device = torch.device(device_name)\n",
"\n",
"\n",
"# Normalization\n",
"def preprocess_image(img, qmin=0.01, qmax=0.99):\n",
" \"\"\"Unit interval preprocessing\"\"\"\n",
" img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))\n",
" return img"
],
"metadata": {
"id": "f6nsdOgpY1sS"
},
"execution_count": 3,
"outputs": []
},
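{
"cell_type": "markdown",
"source": [
"The `preprocess_image` helper rescales an input volume so that its 1st and 99th percentiles map to 0 and 1, matching the normalization used during training. Below is a minimal sketch of its use on a random tensor standing in for an MRI volume; in practice you would load a real volume (e.g., with nibabel) before preprocessing."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Minimal sketch: normalize a random stand-in volume.\n",
"# Values can fall slightly outside [0, 1] because the scaling uses the\n",
"# 1% and 99% quantiles rather than the min and max.\n",
"sample_volume = torch.rand(64, 64, 64)\n",
"normalized = preprocess_image(sample_volume)\n",
"print(normalized.min().item(), normalized.max().item())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},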
{
"cell_type": "code",
"source": [
"# specify how many classes does the model predict\n",
"n_classes = 3\n",
"# specify the architecture\n",
"config_file = \"modelAE.json\"\n",
"# how many channels does the saved model have\n",
"model_channels = 15\n",
"# path to the saved model\n",
"model_path = \"model.pth\"\n",
"# tfjs model output directory with colab\n",
"tfjs_model_dir = \"model_tfjs\"\n",
"\n",
"meshnet_model = enMesh_checkpoint(\n",
" in_channels=1,\n",
" n_classes=n_classes,\n",
" channels=model_channels,\n",
" config_file=config_file,\n",
")\n",
"\n",
"checkpoint = torch.load(model_path)\n",
"meshnet_model.load_state_dict(checkpoint)"
],
"metadata": {
"id": "haVQV-84Y83f",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4928de22-a3d2-41f8-e095-be2bc6412a15"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {},
"execution_count": 4
}
]
},
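{
"cell_type": "markdown",
"source": [
"As an optional sanity check after loading the checkpoint, the sketch below counts the model's parameters; the exact number simply reflects the architecture configured above."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sanity check: count the parameters of the loaded model.\n",
"n_params = sum(p.numel() for p in meshnet_model.parameters())\n",
"print(f\"{n_params:,} parameters\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},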
{
"cell_type": "code",
"source": [
"meshnet_model.eval()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g0dkzzwhaWTn",
"outputId": "49ee826b-a774-437f-e79c-91bc81241233"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"enMesh_checkpoint(\n",
" (model): Sequential(\n",
" (0): Sequential(\n",
" (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (2): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (4): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (5): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (7): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (8): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (9): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(16, 16, 16), dilation=(16, 16, 16))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (10): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (11): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (12): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (13): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (14): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (15): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (16): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (17): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (18): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (19): Conv3d(15, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"meshnet_model.to(device)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gl0psSCoaXPp",
"outputId": "6225d17e-b851-4bf9-d269-540efd5c1a72"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"enMesh_checkpoint(\n",
" (model): Sequential(\n",
" (0): Sequential(\n",
" (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (2): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (4): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (5): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (7): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (8): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (9): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(16, 16, 16), dilation=(16, 16, 16))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (10): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (11): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (12): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (13): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (14): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (15): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (16): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (17): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (18): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
" (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (19): Conv3d(15, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n",
" )\n",
")"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"source": [
"This function takes a sequential block and fuses the batch normalization with convolution"
],
"metadata": {
"id": "ZoWpbvPWTwcf"
}
},
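{
"cell_type": "markdown",
"source": [
"Concretely, for a convolution with weights $W$ and bias $b$ followed by a batch norm with learned parameters $\\gamma, \\beta$, running statistics $\\mu, \\sigma^2$, and stabilizer $\\epsilon$, fusion rewrites the convolution as\n",
"\n",
"$$W' = \\frac{\\gamma}{\\sqrt{\\sigma^2 + \\epsilon}}\\, W, \\qquad b' = \\frac{\\gamma\\,(b - \\mu)}{\\sqrt{\\sigma^2 + \\epsilon}} + \\beta$$\n",
"\n",
"so the fused model is mathematically equivalent in eval mode while carrying fewer layers and weights."
],
"metadata": {}
},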
{
"cell_type": "code",
"source": [
"mnm = fuse_bn_recursively(meshnet_model)\n",
"\n",
"del meshnet_model\n",
"mnm.model.eval()"
],
"metadata": {
"id": "bgempY9maeLF",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "05252348-06b3-4516-e2c6-61131d526f21"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Sequential(\n",
" (0): Sequential(\n",
" (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (1): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (2): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (3): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (4): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (5): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (6): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (7): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (8): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (9): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(16, 16, 16), dilation=(16, 16, 16))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (10): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (11): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (12): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (13): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (14): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (15): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (16): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (17): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (18): Sequential(\n",
" (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
" (1): ELU(alpha=1.0, inplace=True)\n",
" )\n",
" (19): Conv3d(15, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n",
")"
]
},
"metadata": {},
"execution_count": 7
}
]
},
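{
"cell_type": "markdown",
"source": [
"As an optional check that the fused network still runs end to end, we can push a small random volume through the fused `Sequential`. This sketch uses a 32³ cube purely to verify shapes; Brainchop's models operate on full-size MRI volumes."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sanity check: a random 32^3 volume through the fused layers.\n",
"# Every conv uses 'same'-style padding, so the spatial size is preserved\n",
"# and the output channel count should equal n_classes (3).\n",
"with torch.no_grad():\n",
"    dummy = torch.rand(1, 1, 32, 32, 32, device=device)\n",
"    out = mnm.model(dummy)\n",
"print(out.shape)  # expected: torch.Size([1, 3, 32, 32, 32])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},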
{
"cell_type": "markdown",
"source": [
"Convert MeshNet model to TensorFlow.js"
],
"metadata": {
"id": "6mSundvNUsRC"
}
},
{
"cell_type": "code",
"source": [
"meshnet2tfjs(mnm, tfjs_model_dir)"
],
"metadata": {
"id": "ywZ4t1jaarKH"
},
"execution_count": 8,
"outputs": []
},
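{
"cell_type": "markdown",
"source": [
"Before zipping, we can peek at what the converter wrote. This sketch only assumes the two files produced above, `model_tfjs/model.json` (topology and weights manifest) and `model_tfjs/model.bin` (concatenated weights), and prints their top-level structure and size."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import json\n",
"import os\n",
"\n",
"# Peek at the converted artifacts: model.json holds the topology and\n",
"# weights manifest; model.bin holds the concatenated weight values.\n",
"with open(os.path.join(tfjs_model_dir, \"model.json\")) as f:\n",
"    model_json = json.load(f)\n",
"\n",
"print(\"model.json top-level keys:\", list(model_json.keys()))\n",
"print(\"model.bin bytes:\", os.path.getsize(os.path.join(tfjs_model_dir, \"model.bin\")))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},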
{
"cell_type": "code",
"source": [
"!ls"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "egMsj_0ReTEg",
"outputId": "50a1e41a-0467-4a82-a59c-e2778606fb43"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"bcmodel.zip\t meshnet2tfjs.py\tmeshnet.py.1\tmodel.pth __pycache__\n",
"blendbatchnorm.py meshnet2tfjs.py.1\tmodelAE.json\tmodel.pth.1 sample_data\n",
"blendbatchnorm.py.1 meshnet.py\t\tmodelAE.json.1\tmodel_tfjs\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Save converted files to zip file"
],
"metadata": {
"id": "GAEZ9-NfU-ad"
}
},
{
"cell_type": "code",
"source": [
"!zip -r \"/content/bcmodel.zip\" \"model_tfjs\""
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bfp6oVRpfYF3",
"outputId": "506d7e67-9a0c-4ece-c346-92550ff757f1"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"updating: model_tfjs/ (stored 0%)\n",
"updating: model_tfjs/model.json (deflated 97%)\n",
"updating: model_tfjs/model.bin (deflated 6%)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Download TensorFlow model"
],
"metadata": {
"id": "1u6OgFCyVEyb"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import files\n",
"files.download(\"/content/bcmodel.zip\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"id": "3NgMj3Hxf4iO",
"outputId": "6bb5b6cb-924e-4c94-9f71-53e60f11b35a"
},
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"download(\"download_ffa47f24-3854-4ba2-b4e1-e343813d6927\", \"bcmodel.zip\", 415840)"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# **Final notes**"
],
"metadata": {
"id": "Q2JEHsWhZvgk"
}
},
{
"cell_type": "markdown",
"source": [
"This tutorial aims to provide a simple example of how to convert segmentation model from python pipeline to TensorFlow.js files (i.e. model.json and model.bin). "
],
"metadata": {
"id": "q_yFRL5BZ22b"
}
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "T4"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}