{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "<a href=\"https://colab.research.google.com/github/neuroneural/brainchop/blob/master/py2tfjs/Convert_Trained_Model_To_TFJS.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ],
      "metadata": {
        "id": "A1c2bimsUbI2"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# PyTorch Model Conversion to tfjs"
      ],
      "metadata": {
        "id": "ef8frdPU3AWh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "In this example, you will find simple steps for converting a trained **MeshNet** model to a tfjs model that can be used with [**Brainchop**](https://neuroneural.github.io/brainchop/). Given a Gray Matter White Matter (GWM) segmentation model that segments the brain into three regions, a successful conversion to tfjs produces two main files, the model.json file and the weights bin file:\n",
        "\n",
        "- The `model.json` file contains the model topology and the weights manifest.\n",
        "- The binary weights file (i.e. `*.bin`) contains the concatenated weight values.\n",
        "\n",
        "\n",
        "\n",
        "This conversion pipeline example is part of the [**Brainchop**](https://neuroneural.github.io/brainchop/) project, where the basic MeshNet model is trained using **PyTorch** and the resulting model is converted to a **TensorFlow.js** (tfjs) model for use with Brainchop.\n",
        "\n",
        "---\n",
        "\n",
        "**Authors:** [Mohamed Masoud](https://github.com/Mmasoud1), and [Sergey Plis](https://github.com/sergeyplis)\n",
        "\n"
      ],
      "metadata": {
        "id": "tKyWAoiv3mN3"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Fetch required python files from brainchop repository [**folder**](https://github.com/neuroneural/brainchop/tree/master/py2tfjs/conversion_example)"
      ],
      "metadata": {
        "id": "jTqL2GwjXszv"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We'll import three Python scripts as modules, so they first need to be downloaded from the Brainchop repository.\n"
      ],
      "metadata": {
        "id": "JHsePqAZ24A-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/blendbatchnorm.py\"}\n",
        "!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet.py\"}\n",
        "!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet2tfjs.py\"}\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vekwxqlZUf0P",
        "outputId": "39660321-eedb-4cdc-a5fa-c999e2482a3e"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2024-03-28 09:08:58--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/blendbatchnorm.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 2254 (2.2K) [text/plain]\n",
            "Saving to: ‘blendbatchnorm.py’\n",
            "\n",
            "\rblendbatchnorm.py     0%[                    ]       0  --.-KB/s               \rblendbatchnorm.py   100%[===================>]   2.20K  --.-KB/s    in 0s      \n",
            "\n",
            "2024-03-28 09:08:58 (39.2 MB/s) - ‘blendbatchnorm.py’ saved [2254/2254]\n",
            "\n",
            "--2024-03-28 09:08:58--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 4711 (4.6K) [text/plain]\n",
            "Saving to: ‘meshnet.py’\n",
            "\n",
            "meshnet.py          100%[===================>]   4.60K  --.-KB/s    in 0s      \n",
            "\n",
            "2024-03-28 09:08:58 (73.0 MB/s) - ‘meshnet.py’ saved [4711/4711]\n",
            "\n",
            "--2024-03-28 09:08:58--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/meshnet2tfjs.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 8711 (8.5K) [text/plain]\n",
            "Saving to: ‘meshnet2tfjs.py’\n",
            "\n",
            "meshnet2tfjs.py     100%[===================>]   8.51K  --.-KB/s    in 0s      \n",
            "\n",
            "2024-03-28 09:08:58 (103 MB/s) - ‘meshnet2tfjs.py’ saved [8711/8711]\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Next, download the saved, trained PyTorch model and its architecture config file; this is the model that will be converted to TFJS."
      ],
      "metadata": {
        "id": "tWea0YpISfBu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/modelAE.json\"}\n",
        "!wget --no-cache --backups=1 {f\"https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/model.pth\"}\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T_PKTsVHQt82",
        "outputId": "c842f2f5-ba30-4d97-c356-826f41e69136"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2024-03-28 09:09:02--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/modelAE.json\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 3044 (3.0K) [text/plain]\n",
            "Saving to: ‘modelAE.json’\n",
            "\n",
            "\rmodelAE.json          0%[                    ]       0  --.-KB/s               \rmodelAE.json        100%[===================>]   2.97K  --.-KB/s    in 0s      \n",
            "\n",
            "2024-03-28 09:09:02 (45.2 MB/s) - ‘modelAE.json’ saved [3044/3044]\n",
            "\n",
            "--2024-03-28 09:09:02--  https://raw.githubusercontent.com/neuroneural/brainchop/master/py2tfjs/conversion_example/model.pth\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 484293 (473K) [application/octet-stream]\n",
            "Saving to: ‘model.pth’\n",
            "\n",
            "model.pth           100%[===================>] 472.94K  --.-KB/s    in 0.02s   \n",
            "\n",
            "2024-03-28 09:09:02 (20.3 MB/s) - ‘model.pth’ saved [484293/484293]\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from blendbatchnorm import fuse_bn_recursively\n",
        "from meshnet2tfjs import meshnet2tfjs\n",
        "from meshnet import (\n",
        "    MeshNet,\n",
        "    enMesh_checkpoint,\n",
        ")\n",
        "\n",
        "device_name = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
        "device = torch.device(device_name)\n",
        "\n",
        "\n",
        "# Normalization\n",
        "def preprocess_image(img, qmin=0.01, qmax=0.99):\n",
        "    \"\"\"Unit interval preprocessing\"\"\"\n",
        "    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))\n",
        "    return img"
      ],
      "metadata": {
        "id": "f6nsdOgpY1sS"
      },
      "execution_count": 3,
      "outputs": []
    },
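    {
      "cell_type": "markdown",
      "source": [
        "A quick sanity check of `preprocess_image`, as a minimal sketch: the random tensor below is only a stand-in for a real scan, and the name `volume` and its size are assumptions for illustration. After normalization the 1% and 99% quantiles should land at approximately 0 and 1. Because the function does not clamp, the minimum and maximum can fall slightly outside the unit interval."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Synthetic stand-in for a scan (illustrative only, not real data)\n",
        "volume = torch.rand(16, 16, 16) * 255\n",
        "normalized = preprocess_image(volume)\n",
        "\n",
        "# The chosen quantiles map to roughly 0 and 1\n",
        "print(normalized.quantile(0.01).item(), normalized.quantile(0.99).item())\n",
        "# Values beyond those quantiles fall outside [0, 1] because nothing is clamped\n",
        "print(normalized.min().item(), normalized.max().item())"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },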
    {
      "cell_type": "code",
      "source": [
        "# number of classes the model predicts\n",
        "n_classes = 3\n",
        "# JSON file describing the architecture\n",
        "config_file = \"modelAE.json\"\n",
        "# number of channels in the saved model\n",
        "model_channels = 15\n",
        "# path to the saved model weights\n",
        "model_path = \"model.pth\"\n",
        "# output directory for the tfjs model (relative to the Colab working directory)\n",
        "tfjs_model_dir = \"model_tfjs\"\n",
        "\n",
        "meshnet_model = enMesh_checkpoint(\n",
        "    in_channels=1,\n",
        "    n_classes=n_classes,\n",
        "    channels=model_channels,\n",
        "    config_file=config_file,\n",
        ")\n",
        "\n",
        "# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime\n",
        "checkpoint = torch.load(model_path, map_location=device)\n",
        "meshnet_model.load_state_dict(checkpoint)"
      ],
      "metadata": {
        "id": "haVQV-84Y83f",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4928de22-a3d2-41f8-e095-be2bc6412a15"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "meshnet_model.eval()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "g0dkzzwhaWTn",
        "outputId": "49ee826b-a774-437f-e79c-91bc81241233"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "enMesh_checkpoint(\n",
              "  (model): Sequential(\n",
              "    (0): Sequential(\n",
              "      (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (1): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (2): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (3): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (4): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (5): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (6): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (7): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (8): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (9): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(16, 16, 16), dilation=(16, 16, 16))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (10): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (11): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (12): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (13): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (14): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (15): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (16): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (17): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (18): Sequential(\n",
              "      (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
              "      (1): BatchNorm3d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (2): ELU(alpha=1.0, inplace=True)\n",
              "    )\n",
              "    (19): Conv3d(15, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "meshnet_model.to(device)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gl0psSCoaXPp",
        "outputId": "6225d17e-b851-4bf9-d269-540efd5c1a72"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "`fuse_bn_recursively` takes a sequential block and fuses each batch normalization layer into the convolution that precedes it, so the fused model no longer contains `BatchNorm3d` layers."
      ],
      "metadata": {
        "id": "ZoWpbvPWTwcf"
      }
    },
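    {
      "cell_type": "markdown",
      "source": [
        "As a minimal sketch of what this kind of fusion does (this is not the code in `blendbatchnorm.py`, whose internals are not shown here), the cell below folds a single `BatchNorm3d` into the preceding `Conv3d` by hand and compares the outputs in eval mode. The helper name `fold_bn_into_conv`, the toy layer sizes, and the assumption of an ungrouped convolution are all made up for illustration."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "\n",
        "def fold_bn_into_conv(conv, bn):\n",
        "    # Hypothetical helper: build a Conv3d equivalent to bn(conv(x)) in eval mode.\n",
        "    # Assumes groups=1 and default padding mode.\n",
        "    fused = nn.Conv3d(\n",
        "        conv.in_channels,\n",
        "        conv.out_channels,\n",
        "        conv.kernel_size,\n",
        "        stride=conv.stride,\n",
        "        padding=conv.padding,\n",
        "        dilation=conv.dilation,\n",
        "        bias=True,\n",
        "    )\n",
        "    with torch.no_grad():\n",
        "        # Per-output-channel scale: gamma / sqrt(running_var + eps)\n",
        "        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)\n",
        "        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)\n",
        "        fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1, 1))\n",
        "        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)\n",
        "    return fused\n",
        "\n",
        "\n",
        "torch.manual_seed(0)\n",
        "conv = nn.Conv3d(1, 4, kernel_size=3, padding=1)\n",
        "bn = nn.BatchNorm3d(4)\n",
        "# Give the batch norm non-trivial statistics so the comparison is meaningful\n",
        "with torch.no_grad():\n",
        "    bn.running_mean.uniform_(-1, 1)\n",
        "    bn.running_var.uniform_(0.5, 2.0)\n",
        "    bn.weight.uniform_(0.5, 1.5)\n",
        "    bn.bias.uniform_(-1, 1)\n",
        "conv.eval()\n",
        "bn.eval()\n",
        "\n",
        "x = torch.randn(1, 1, 8, 8, 8)\n",
        "with torch.no_grad():\n",
        "    reference = bn(conv(x))\n",
        "    fused_out = fold_bn_into_conv(conv, bn)(x)\n",
        "\n",
        "# Largest absolute difference between the two paths (expected to be at float32 rounding level)\n",
        "print((reference - fused_out).abs().max().item())"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },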
    {
      "cell_type": "code",
      "source": [
        "mnm = fuse_bn_recursively(meshnet_model)\n",
        "\n",
        "del meshnet_model\n",
        "mnm.model.eval()"
      ],
      "metadata": {
        "id": "bgempY9maeLF",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "05252348-06b3-4516-e2c6-61131d526f21"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Sequential(\n",
              "  (0): Sequential(\n",
              "    (0): Conv3d(1, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (1): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (2): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (3): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (4): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (5): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (6): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (7): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (8): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (9): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(16, 16, 16), dilation=(16, 16, 16))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (10): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(12, 12, 12), dilation=(12, 12, 12))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (11): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (12): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(7, 7, 7), dilation=(7, 7, 7))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (13): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(6, 6, 6), dilation=(6, 6, 6))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (14): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(5, 5, 5), dilation=(5, 5, 5))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (15): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (16): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(3, 3, 3), dilation=(3, 3, 3))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (17): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (18): Sequential(\n",
              "    (0): Conv3d(15, 15, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
              "    (1): ELU(alpha=1.0, inplace=True)\n",
              "  )\n",
              "  (19): Conv3d(15, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Convert MeshNet model to TensorFlow.js"
      ],
      "metadata": {
        "id": "6mSundvNUsRC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "meshnet2tfjs(mnm, tfjs_model_dir)"
      ],
      "metadata": {
        "id": "ywZ4t1jaarKH"
      },
      "execution_count": 8,
      "outputs": []
    },
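    {
      "cell_type": "markdown",
      "source": [
        "To inspect the conversion result, the following sketch lists the files written to the output directory and prints the top-level keys of `model.json`. It assumes the converter wrote a file named `model.json` into `model_tfjs`; the variable name `model_dir` is introduced here for illustration, and no particular key names are assumed."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import os\n",
        "\n",
        "# Same value as tfjs_model_dir above, repeated so this cell stands on its own\n",
        "model_dir = 'model_tfjs'\n",
        "\n",
        "# Files produced by the conversion step\n",
        "print(sorted(os.listdir(model_dir)))\n",
        "\n",
        "# Top-level structure of the model description\n",
        "with open(os.path.join(model_dir, 'model.json')) as f:\n",
        "    model_json = json.load(f)\n",
        "print(list(model_json.keys()))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },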
    {
      "cell_type": "code",
      "source": [
        "!ls"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "egMsj_0ReTEg",
        "outputId": "50a1e41a-0467-4a82-a59c-e2778606fb43"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "bcmodel.zip\t     meshnet2tfjs.py\tmeshnet.py.1\tmodel.pth    __pycache__\n",
            "blendbatchnorm.py    meshnet2tfjs.py.1\tmodelAE.json\tmodel.pth.1  sample_data\n",
            "blendbatchnorm.py.1  meshnet.py\t\tmodelAE.json.1\tmodel_tfjs\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Save the converted files to a zip file"
      ],
      "metadata": {
        "id": "GAEZ9-NfU-ad"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!zip -r \"/content/bcmodel.zip\" \"model_tfjs\""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Bfp6oVRpfYF3",
        "outputId": "506d7e67-9a0c-4ece-c346-92550ff757f1"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "updating: model_tfjs/ (stored 0%)\n",
            "updating: model_tfjs/model.json (deflated 97%)\n",
            "updating: model_tfjs/model.bin (deflated 6%)\n"
          ]
        }
      ]
    },
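    {
      "cell_type": "markdown",
      "source": [
        "The `updating:` lines in the output above indicate that `bcmodel.zip` already existed when `zip` ran, and `zip -r` does not by default remove entries left over from an earlier run. The cell below is a small sketch, using only Python's standard `zipfile` module, that lists what the archive actually contains before it is downloaded. The helper name `list_zip_entries` is hypothetical."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import zipfile\n",
        "\n",
        "\n",
        "def list_zip_entries(zip_path):\n",
        "    # Hypothetical helper: return the names of all entries in the archive.\n",
        "    with zipfile.ZipFile(zip_path) as archive:\n",
        "        return archive.namelist()\n",
        "\n",
        "\n",
        "# Path taken from the zip command above.\n",
        "if os.path.isfile('/content/bcmodel.zip'):\n",
        "    print(list_zip_entries('/content/bcmodel.zip'))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },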
    {
      "cell_type": "markdown",
      "source": [
        "Download the zipped TensorFlow.js model"
      ],
      "metadata": {
        "id": "1u6OgFCyVEyb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import files\n",
        "files.download(\"/content/bcmodel.zip\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "3NgMj3Hxf4iO",
        "outputId": "6bb5b6cb-924e-4c94-9f71-53e60f11b35a"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_ffa47f24-3854-4ba2-b4e1-e343813d6927\", \"bcmodel.zip\", 415840)"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **Final notes**"
      ],
      "metadata": {
        "id": "Q2JEHsWhZvgk"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This tutorial provides a simple example of how to convert a segmentation model from a Python pipeline to TensorFlow.js files (i.e., model.json and model.bin)."
      ],
      "metadata": {
        "id": "q_yFRL5BZ22b"
      }
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "T4"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}