<a href="https://colab.research.google.com/github/neuroneural/brainchop/blob/master/py2tfjs/MeshNet_Training_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This tutorial shows how to train the **MeshNet** model for gray matter/white matter (GWM) segmentation of brain MRI. The task segments each brain volume into three classes: background, white matter, and gray matter. The model is trained on sample volumes from the **Mindboggle 101** brain MRI dataset for this multiclass 3D segmentation task.


This training pipeline example is part of the [**Brainchop**](https://neuroneural.github.io/brainchop/) project. The basic MeshNet model is trained with **PyTorch**, and the resulting model can be converted to a **TensorFlow.js** (tfjs) model for use with Brainchop.

For more information about the full conversion process, please refer to the repository [Wiki](https://github.com/neuroneural/brainchop/wiki).

---

This tutorial was developed by [Pratyush Reddy](mailto:pratyushrg@gmail.com) and revised by [Mohamed Masoud](mailto:mohamedemory@gmail.com) and [Sergey Plis](mailto:s.m.plis@gmail.com).



# Imports

* Install the nibabel and nilearn packages for reading and plotting 3D brain MRI scans.
* Import the essential libraries.

In [None]:
!pip install nibabel nilearn

Collecting nilearn
  Downloading nilearn-0.10.2-py3-none-any.whl (10.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.4/10.4 MB 79.9 MB/s eta 0:00:00
Installing collected packages: nilearn
Successfully installed nilearn-0.10.2


In [None]:
import numpy as np
import pandas as pd
import nibabel as nib
import ipywidgets as widgets
from nilearn import plotting
import matplotlib.pyplot as plt
from IPython.display import display
from collections import OrderedDict
import torch
import os
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential
from torch.utils.data import DataLoader

# Download the required dataset

We use samples from the Mindboggle 101 brain MRI dataset for a multiclass 3D segmentation task.
* Contents:
  1. The brain datasets consist of T1-weighted scans plus the prepared (limited) label volumes.
  2. Training (10 image/label pairs), validation (2 pairs), and inference (3 pairs) datasets.

In [None]:
!wget https://meshnet-pr-dataset.s3.amazonaws.com/data-1-10.zip
!unzip data-1-10.zip
!rm data-1-10.zip

--2023-11-27 06:30:53--  https://meshnet-pr-dataset.s3.amazonaws.com/data-1-10.zip
Resolving meshnet-pr-dataset.s3.amazonaws.com (meshnet-pr-dataset.s3.amazonaws.com)... 54.231.200.33, 52.217.196.65, 52.216.24.68, ...
Connecting to meshnet-pr-dataset.s3.amazonaws.com (meshnet-pr-dataset.s3.amazonaws.com)|54.231.200.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 128402177 (122M) [application/zip]
Saving to: ‘data-1-10.zip’


2023-11-27 06:30:56 (34.8 MB/s) - ‘data-1-10.zip’ saved [128402177/128402177]

Archive:  data-1-10.zip
  inflating: data/coords_generator.py  
  inflating: data/dataset_infer.csv  
  inflating: data/dataset_train.csv  
  inflating: data/model.py           
  inflating: data/brain_dataset(1).py  
  inflating: data/reader.py          
  inflating: data/brain_dataset.py   
  inflating: data/dataset_valid.csv  
  inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-18/t1weighted.nii.gz  
  inflating: data/Mindboggle_101/MMRR-21_volu

# Generate gray/white and anatomic labels from existing label files

In [None]:
pwd

'/content'

This function takes an input label array and produces a three-class mask using specific integer codes: 0 for background (all other labels), 1 for white matter, and 2 for gray matter. The codes follow the FreeSurfer label numbering convention and its color lookup table (ColorLUT):

https://surfer.nmr.mgh.harvard.edu/fswiki/FsTutorial/AnatomicalROI/FreeSurferColorLUT

In [None]:
def labels2graywhite(label):
    """Map FreeSurfer label codes to 0 (background), 1 (white matter), 2 (gray matter)."""
    white_code = [2, 41, 7, 16, 46] + [*range(251, 256)]
    gray_code = (
        [*range(1001, 1004)]
        + [*range(1005, 1036)]
        + [*range(2001, 2004)]
        + [*range(2005, 2036)]
        + [*range(8, 14)]
        + [*range(17, 21)]
        + [*range(26, 29)]
        + [*range(47, 56)]
        + [*range(58, 61)]
    )
    white = np.isin(label, white_code)
    gray = np.isin(label, gray_code)
    return white.astype(np.uint8) + (2 * gray).astype(np.uint8)
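To check the mapping on a toy input, the sketch below re-declares `labels2graywhite` with the same codes as above and applies it to a handful of FreeSurfer codes. The chosen inputs are for illustration only: 2 and 41 (left/right cerebral white matter), 1002 (a left-hemisphere cortical parcel), 10 (left thalamus), and 4 (left lateral ventricle, which is in neither list).

```python
import numpy as np

def labels2graywhite(label):
    """Same mapping as above: 0 = background, 1 = white matter, 2 = gray matter."""
    white_code = [2, 41, 7, 16, 46] + [*range(251, 256)]
    gray_code = ([*range(1001, 1004)] + [*range(1005, 1036)]
                 + [*range(2001, 2004)] + [*range(2005, 2036)]
                 + [*range(8, 14)] + [*range(17, 21)] + [*range(26, 29)]
                 + [*range(47, 56)] + [*range(58, 61)])
    white = np.isin(label, white_code)
    gray = np.isin(label, gray_code)
    return white.astype(np.uint8) + (2 * gray).astype(np.uint8)

codes = np.array([0, 2, 41, 1002, 10, 4])
print(labels2graywhite(codes))  # [0 1 1 2 2 0]
```

Because no code appears in both lists, the sum `white + 2 * gray` never produces a value of 3.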

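This function maps each input label to a contiguous anatomical region index: label 0 and any code not in the `source` list map to 0, and each listed code maps to its position in `source`.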

In [None]:
def labels2anatomic(label):
    source = (
        [0]
        + [i for i in range(1001, 1004)]
        + [i for i in range(1005, 1036)]
        + [i for i in range(2001, 2004)]
        + [i for i in range(2005, 2036)]
        + [10,49,11,50,12,51,13,52,17,53,18,54,26,58,28,60,2,41,4,5,43,44,14,15,24,16,7,46,8,47,251,252,253,254,255,]
        )
    labelmap = {x: idx for idx, x in enumerate(source)}

    @np.vectorize
    def relabel(x):
        y = 0
        if x in labelmap:
            y = labelmap[x]
        return y
    return relabel(label).astype(np.uint8)
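`labels2anatomic` calls a Python function per voxel through `np.vectorize`, which is essentially a Python loop and is slow on 256×256×256 volumes. A lookup table is one common vectorized alternative. The sketch below is an assumption, not the authors' code: the hypothetical `labels2anatomic_lut` builds an array indexed by FreeSurfer code from the same `source` list and maps any code outside the table to 0, as the original does. It assumes integer-valued labels.

```python
import numpy as np

# Same source list as labels2anatomic; position in the list = new region index
SOURCE = (
    [0]
    + list(range(1001, 1004))
    + list(range(1005, 1036))
    + list(range(2001, 2004))
    + list(range(2005, 2036))
    + [10, 49, 11, 50, 12, 51, 13, 52, 17, 53, 18, 54, 26, 58, 28, 60, 2, 41,
       4, 5, 43, 44, 14, 15, 24, 16, 7, 46, 8, 47, 251, 252, 253, 254, 255]
)

def labels2anatomic_lut(label):
    """Hypothetical vectorized equivalent of labels2anatomic using a lookup table."""
    lut = np.zeros(max(SOURCE) + 1, dtype=np.uint8)  # unknown codes default to 0
    lut[SOURCE] = np.arange(len(SOURCE), dtype=np.uint8)
    label = np.asarray(label).astype(np.int64)
    in_range = (label >= 0) & (label < lut.size)     # codes beyond the table map to 0
    out = np.zeros(label.shape, dtype=np.uint8)
    out[in_range] = lut[label[in_range]]
    return out

print(labels2anatomic_lut(np.array([0, 1001, 1005, 2001, 10, 255, 9999, 4])))
# [  0   1   4  35  69 103   0  87]
```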

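This function builds the file name for a label volume generated by the **labels2anatomic** or **labels2graywhite** functions above: it adds a prefix (for example `GW`) to the base name of the original label file and keeps its directory.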

In [None]:
def create_label(label, prefix):
  temp=label.split('/')[:-1]
  temp.append(prefix+label.split('/')[-1])
  return '/'.join(temp)
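An equivalent of `create_label` can be written with `os.path.split` and `os.path.join`, which make the directory/file-name split explicit. The sketch below uses the hypothetical name `create_label_path`; the example input file name `labels.DKT31.manual+aseg.nii.gz` is inferred from the `GW`-prefixed names printed later in this notebook.

```python
import os

def create_label_path(label_path, prefix):
    """Prepend `prefix` to the file name of `label_path`, keeping its directory."""
    directory, filename = os.path.split(label_path)
    return os.path.join(directory, prefix + filename)

print(create_label_path(
    "data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/labels.DKT31.manual+aseg.nii.gz", "GW"))
# data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/GWlabels.DKT31.manual+aseg.nii.gz
```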

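This function updates a dataset CSV file with a new `GWlabels` column that holds the paths of the gray/white label files. The corresponding `ANAlabels` column for anatomic labels is commented out.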

In [None]:
def update_csv(CSV_file):
  data = pd.read_csv(CSV_file)
  data['GWlabels']=np.array([create_label(i, 'GW') for i in data['labels']], dtype=object)
  # data['ANAlabels']=np.array([create_label(i, 'ANA') for i in data['labels']], dtype=object)
  # data.drop(['nii_labels'], inplace=True, axis=1)
  data.to_csv(CSV_file, index=False)

This function creates the gray/white label volumes with **labels2graywhite** and saves each one to the path stored in the `GWlabels` column.

In [None]:
def create_greywhite(CSV_file):
  data = pd.read_csv(CSV_file)
  for label, GWlabel in zip(data.labels, data.GWlabels):
    print(os.path.dirname(GWlabel), os.path.basename(GWlabel))
    img_nifti = nib.load(label)
    img = np.asarray(img_nifti.dataobj)
    # Keep the original affine so the new label volume stays aligned with the T1 scan
    ni_img = nib.Nifti1Image(labels2graywhite(img), affine=img_nifti.affine)
    nib.save(ni_img, GWlabel)

This function creates the anatomic label volumes with **labels2anatomic** and saves each one to the path stored in the `ANAlabels` column.

In [None]:
def create_ANA(CSV_file):
  data = pd.read_csv(CSV_file)
  for label, ANAlabel in zip(data.labels, data.ANAlabels):
    print(os.path.dirname(ANAlabel), os.path.basename(ANAlabel))
    img_nifti = nib.load(label)
    img = np.asarray(img_nifti.dataobj)
    # Keep the original affine so the new label volume stays aligned with the T1 scan
    ni_img = nib.Nifti1Image(labels2anatomic(img), affine=img_nifti.affine)
    nib.save(ni_img, ANAlabel)

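Update the dataset CSV files (train, infer, and valid) and generate the gray/white labels using **update_csv** and **create_greywhite**. The **create_ANA** calls are commented out; they require the `ANAlabels` column, which is also commented out in **update_csv**.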

In [None]:
update_csv('./data/dataset_train.csv')
create_greywhite('./data/dataset_train.csv')
# create_ANA('./data/dataset_train.csv')
update_csv('./data/dataset_infer.csv')
create_greywhite('./data/dataset_infer.csv')
# create_ANA('./data/dataset_infer.csv')
update_csv('./data/dataset_valid.csv')
create_greywhite('./data/dataset_valid.csv')
# create_ANA('./data/dataset_valid.csv')

data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-19 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-18 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-11 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-8 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-15 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/OASIS-TRT-20_volumes/OASIS-TRT-20-13 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-9 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/Extra-18_volumes/HLN-12-12 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/Extra-18_volumes/MMRR-3T7T-2-1 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-5 GWlabels.DKT31.manual+aseg.nii.gz
data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-6 GWlabels.DKT31.

# Plotting Functions

The code below creates interactive sliders to visualize 3D image volumes and labels. The **peek_class_new** class handles slider creation and display, and **nibabel** is used to load and prepare the image and label data. We also use this functionality to visualize the predictions of the trained model.

In [None]:
# Create x, y, and z coordinate sliders
class peek_class_new:
    def __init__(self,scan_data, label_bool):
        self.label_bool = label_bool
        self.scan_data = scan_data
        if self.label_bool:
          self.x_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[0]-1, value=(self.scan_data.shape[0]-1)//2, description='X')
          self.y_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[1]-1, value=(self.scan_data.shape[1]-1)//2, description='Y')
          self.z_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[2]-1, value=(self.scan_data.shape[2]-1)//2, description='Z')
        else:
          self.x_slider = widgets.IntSlider(min=(self.scan_data.shape[0]-1)*-1, max=0, value=-((self.scan_data.shape[0]-1)//2), description='X')
          self.y_slider = widgets.IntSlider(min=(self.scan_data.shape[1]-1)*-1, max=0, value=-((self.scan_data.shape[1]-1)//2), description='Y')
          # Use the z dimension (shape[2]) for the z slider range
          self.z_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[2]-1, value=(self.scan_data.shape[2]-1)//2, description='Z')

    def shape(self):
      print(f" {self.scan_data.shape[0]} {self.scan_data.shape[1]} {self.scan_data.shape[2]}")

    def update_slices(self, x, y, z):
      # Draw orthogonal slices through (x, y, z) and mark the cut point
      disp = plotting.plot_anat(self.scan_data, cut_coords=(x, y, z))
      disp.add_markers(marker_coords=[[x, y, z]])
      # Display the plot
      plotting.show()

    def plots(self):
      # Link the sliders to the update function
      widgets.interact(self.update_slices, x=self.x_slider, y=self.y_slider, z=self.z_slider)
      # Display the sliders
      display(self.x_slider, self.y_slider, self.z_slider)

In [None]:
img= nib.load('data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/t1weighted.nii.gz').get_fdata(dtype=np.float32)
volume_shape = [256, 256, 256]
temp= np.zeros(volume_shape)
temp[: img.shape[0], : img.shape[1], : img.shape[2]] = img
image=temp
# Create a NIfTI image object
nifi_image = nib.Nifti1Image(image, affine=np.eye(4))  # Use identity affine matrix for simplicity
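The padding above copies the scan into the corner of a 256×256×256 zero array, the same pattern used for the label volume and inside `DataLoaderClass`. That slice assignment raises a broadcasting error if any dimension of the scan exceeds 256. A hypothetical helper such as `pad_or_crop` (not part of the original code) pads smaller volumes and crops larger ones; unlike `np.zeros(volume_shape)`, which creates `float64`, it keeps the input dtype.

```python
import numpy as np

def pad_or_crop(volume, target_shape=(256, 256, 256)):
    """Zero-pad or crop a 3D array to target_shape, anchored at index (0, 0, 0)."""
    volume = np.asarray(volume)
    out = np.zeros(target_shape, dtype=volume.dtype)
    # Copy only the overlapping region; anything beyond target_shape is dropped
    sx, sy, sz = (min(a, b) for a, b in zip(volume.shape, target_shape))
    out[:sx, :sy, :sz] = volume[:sx, :sy, :sz]
    return out

print(pad_or_crop(np.ones((2, 3, 4)), (4, 4, 4)).sum())  # 24.0 (padded)
print(pad_or_crop(np.ones((5, 5, 5)), (4, 4, 4)).sum())  # 64.0 (cropped)
```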

In [None]:
images = peek_class_new(nifi_image,1)
images.plots()

interactive(children=(IntSlider(value=127, description='X', max=255), IntSlider(value=127, description='Y', ma…

IntSlider(value=127, description='X', max=255)

IntSlider(value=127, description='Y', max=255)

IntSlider(value=127, description='Z', max=255)

In [None]:
img= nib.load('data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/GWlabels.DKT31.manual+aseg.nii.gz').get_fdata(dtype=np.float32)
volume_shape = [256, 256, 256]
temp= np.zeros(volume_shape)
temp[: img.shape[0], : img.shape[1], : img.shape[2]] = img
image=temp
# Create a NIfTI image object
nifi_image = nib.Nifti1Image(image, affine=np.eye(4))  # Use identity affine matrix for simplicity

In [None]:
labels = peek_class_new(nifi_image,1)
labels.plots()

interactive(children=(IntSlider(value=127, description='X', max=255), IntSlider(value=127, description='Y', ma…

IntSlider(value=127, description='X', max=255)

IntSlider(value=127, description='Y', max=255)

IntSlider(value=127, description='Z', max=255)

# MeshNet custom model implementation

In [None]:
MeshNet_5_ae16 = [
    {"in_channels": -1,"kernel_size": 3,"out_channels": 5,"padding": 1,"stride": 1,"dilation": 1,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 2,"stride": 1,"dilation": 2,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 4,"stride": 1,"dilation": 4,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 8,"stride": 1,"dilation": 8,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 16,"stride": 1,"dilation": 16,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 8,"stride": 1,"dilation": 8,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 4,"stride": 1,"dilation": 4,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 2,"stride": 1,"dilation": 2,},
    {"in_channels": 5,"kernel_size": 3,"out_channels": 5,"padding": 1,"stride": 1,"dilation": 1,},
    {"in_channels": 5,"kernel_size": 1,"out_channels": -1,"padding": 0,"stride": 1,"dilation": 1,},
]
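Each layer in `MeshNet_5_ae16` uses `padding` equal to `dilation`, so with a 3×3×3 kernel and stride 1 every layer preserves the spatial size, while the dilation schedule 1, 2, 4, 8, 16, 8, 4, 2, 1 grows the receptive field. The short calculation below (our own illustration, using the standard convolution output-size formula) shows that each output voxel sees a 93-voxel-wide context; the final 1×1×1 layer does not enlarge it.

```python
# Dilations of the nine 3x3x3 layers in MeshNet_5_ae16 (the last 1x1x1 layer adds nothing)
dilations = [1, 2, 4, 8, 16, 8, 4, 2, 1]
kernel_size = 3

def receptive_field(dilations, kernel_size):
    # Each stride-1 conv with dilation d widens the receptive field by d * (k - 1)
    return 1 + sum(d * (kernel_size - 1) for d in dilations)

def output_size(n, kernel_size, padding, dilation, stride=1):
    # Standard convolution output-size formula, applied per spatial dimension
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

print(receptive_field(dilations, kernel_size))                  # 93
print(all(output_size(256, 3, d, d) == 256 for d in dilations))  # True
```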

In [None]:
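import copy

def ae16channels(channels=5, basearch=MeshNet_5_ae16):
    """Return a copy of `basearch` with every hidden layer set to `channels` channels."""
    start = {"out_channels": channels}
    middle = {"in_channels": channels, "out_channels": channels}
    end = {"in_channels": channels}
    modifier = [start] + [middle for _ in range(len(basearch) - 2)] + [end]
    # Deep copy: list.copy() would share (and mutate) the layer dicts of basearch
    newarch = copy.deepcopy(basearch)
    for layer, update in zip(newarch, modifier):
        layer.update(update)
    return newarch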

In [None]:
def conv_w_bn_before_act(dropout_p=0, bnorm=True, gelu=False, *args, **kwargs):
    """Configurable Conv block with Batchnorm and Dropout"""
    sequence = [("conv", nn.Conv3d(*args, **kwargs))]
    if bnorm:
        sequence.append(("bnorm", nn.BatchNorm3d(kwargs["out_channels"])))
    if gelu:
        sequence.append(("gelu", nn.GELU()))
    else:
        sequence.append(("relu", nn.ReLU(inplace=True)))
    sequence.append(("dropout", nn.Dropout3d(dropout_p)))
    layer = nn.Sequential(OrderedDict(sequence))
    return layer

In [None]:
def init_weights(model):
    """Set weights to be xavier normal for all Convs"""
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain("relu"))
            if m.bias is not None:  # convs created with bias=False have no bias
                nn.init.constant_(m.bias, 0.0)

In [None]:
import copy

class MeshNet(nn.Module):
    """Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf"""

    def __init__(self, n_channels, n_classes, large=True, bnorm=True, gelu=False, dropout_p=0):
        """Build the dilated-convolution stack from the layer configuration."""
        super(MeshNet, self).__init__()
        if large:
            params = ae16channels(5)
        else:
            # Deep copy so the module-level MeshNet_5_ae16 template is not modified below
            params = copy.deepcopy(MeshNet_5_ae16)

        params[0]["in_channels"] = n_channels
        params[-1]["out_channels"] = n_classes
        layers = [
            conv_w_bn_before_act(dropout_p=dropout_p, bnorm=bnorm, gelu=gelu, **block_kwargs)
            for block_kwargs in params[:-1]
        ]
        layers.append(nn.Conv3d(**params[-1]))  # final 1x1x1 conv producing class logits
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, x):
        """Forward pass"""
        x = self.model(x)
        return x

In [None]:
class enMesh_checkpoint(MeshNet):
    def train_forward(self, x):
        y = x
        y.requires_grad_()
        y = checkpoint_sequential(
            self.model, len(self.model), y, preserve_rng_state=False
        )
        return y

    def eval_forward(self, x):
        """Forward pass"""
        self.model.eval()
        with torch.inference_mode():
            x = self.model(x)
        return x

    def forward(self, x):
        if self.training:
            return self.train_forward(x)
        else:
            return self.eval_forward(x)

# Sub-volume generator, volume reassembler, and data loader functions

The **CubeDivider** class below divides a cubic 3D tensor into `num_cubes`³ equally sized sub-cubes and reassembles them in the same order, which is useful for processing large 3D volumes in smaller pieces. For more information, please refer to the blog post: https://medium.com/pytorch/catalyst-neuro-a-3d-brain-segmentation-pipeline-for-mri-b1bb1109276a

In [None]:
import torch
from torch.utils.data import Dataset

class CubeDivider:
    def __init__(self, tensor, num_cubes):
        self.tensor = tensor
        self.num_cubes = num_cubes
        self.sub_cube_size = tensor.shape[0] // num_cubes  # Assuming the tensor is a cube

    def divide_into_sub_cubes(self):
        sub_cubes = []

        for i in range(self.num_cubes):
            for j in range(self.num_cubes):
                for k in range(self.num_cubes):
                    sub_cube = self.tensor[
                        i * self.sub_cube_size: (i + 1) * self.sub_cube_size,
                        j * self.sub_cube_size: (j + 1) * self.sub_cube_size,
                        k * self.sub_cube_size: (k + 1) * self.sub_cube_size
                    ].clone()
                    sub_cubes.append(sub_cube)

        sub_cubes = torch.stack(sub_cubes,0)
        return sub_cubes

    @staticmethod
    def reassemble_sub_cubes(sub_cubes):
        sub_cubes = torch.unbind(sub_cubes, dim=0)
        num_cubes = round(len(sub_cubes) ** (1/3))  # round(), not int(): 64 ** (1/3) evaluates to 3.9999999999999996
        sub_cube_size = sub_cubes[0].shape[0]
        tensor_size = num_cubes * sub_cube_size
        tensor = torch.zeros((tensor_size, tensor_size, tensor_size), dtype=torch.float32)

        for i in range(num_cubes):
            for j in range(num_cubes):
                for k in range(num_cubes):
                    sub_cube = sub_cubes[i * num_cubes**2 + j * num_cubes + k]
                    tensor[
                        i * sub_cube_size: (i + 1) * sub_cube_size,
                        j * sub_cube_size: (j + 1) * sub_cube_size,
                        k * sub_cube_size: (k + 1) * sub_cube_size
                    ] = sub_cube

        return tensor

# Usage:
# Assuming tensor is a 3D PyTorch tensor
tensor = torch.randn(32, 32, 32)  # Example tensor
num_cubes = 2  # Number of sub-cubes

divider = CubeDivider(tensor, num_cubes)

# Divide the cube tensor into sub-cubes
sub_cubes = divider.divide_into_sub_cubes()

# Reassemble the sub-cubes to create the original cube tensor
reconstructed_tensor = CubeDivider.reassemble_sub_cubes(sub_cubes)

print(reconstructed_tensor.shape)  # Should be the same as the original tensor shape


torch.Size([32, 32, 32])
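`CubeDivider` loops over all `num_cubes`³ positions in Python. The same split can be expressed with a single `reshape`/`permute`, producing the sub-cubes in the same `i * num_cubes**2 + j * num_cubes + k` order. The functions below are a hypothetical sketch, not part of the original pipeline; like the loop version, they assume a cubic tensor and drop remainder voxels when the side is not divisible by `num_cubes`.

```python
import torch

def divide_into_sub_cubes(tensor, num_cubes):
    """Split a cubic 3D tensor into num_cubes**3 sub-cubes, in CubeDivider's (i, j, k) order."""
    n = num_cubes
    s = tensor.shape[0] // n                  # sub-cube side length
    t = tensor[: n * s, : n * s, : n * s]     # drop any remainder, as the loop version does
    # axes (i, a, j, b, k, c) -> (i, j, k, a, b, c) -> (n**3, s, s, s)
    return t.reshape(n, s, n, s, n, s).permute(0, 2, 4, 1, 3, 5).reshape(-1, s, s, s)

def reassemble_sub_cubes(sub_cubes):
    """Inverse of divide_into_sub_cubes."""
    n = round(sub_cubes.shape[0] ** (1 / 3))
    s = sub_cubes.shape[1]
    # axes (i, j, k, a, b, c) -> (i, a, j, b, k, c) -> (n*s, n*s, n*s)
    return sub_cubes.reshape(n, n, n, s, s, s).permute(0, 3, 1, 4, 2, 5).reshape(n * s, n * s, n * s)

x = torch.randn(32, 32, 32)
subs = divide_into_sub_cubes(x, 4)
print(subs.shape)                                  # torch.Size([64, 8, 8, 8])
print(torch.equal(reassemble_sub_cubes(subs), x))  # True
```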


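The **DataLoaderClass** class below loads and preprocesses the data listed in a CSV file and wraps it in a PyTorch `DataLoader`:

1. The `dataloader` method reads the CSV file, zero-pads each image and label volume to 256×256×256, and z-score normalizes each padded image (subtracts its mean and divides by its standard deviation).
2. Each volume is divided into `coor_factor`³ sub-cubes using the **CubeDivider** class.
3. The labels are converted into a three-channel one-hot encoding (background, white matter, gray matter) and paired with the images in a `TensorDataset`.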
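The one-hot step in `dataloader` builds each channel with a Python loop and NumPy round-trips. PyTorch's `torch.nn.functional.one_hot` can perform the same conversion in one call. The hypothetical `to_one_hot` helper below is a sketch assuming label volumes of shape `(N, 1, D, H, W)` that contain only the integer values 0, 1, and 2 (`one_hot` raises an error for values outside `[0, num_classes)`). It returns `float32`, whereas the original loop produces `float64`.

```python
import torch
import torch.nn.functional as F

def to_one_hot(labels, num_classes=3):
    """Convert (N, 1, D, H, W) integer-valued labels to (N, C, D, H, W) one-hot floats."""
    idx = labels.squeeze(1).long()                     # (N, D, H, W) class indices
    one_hot = F.one_hot(idx, num_classes=num_classes)  # (N, D, H, W, C)
    return one_hot.permute(0, 4, 1, 2, 3).float()      # channel-first, as CrossEntropyLoss expects

labels = torch.tensor([0., 1., 2., 1.]).reshape(1, 1, 1, 2, 2)
print(to_one_hot(labels).shape)  # torch.Size([1, 3, 1, 2, 2])
```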

In [None]:
class DataLoaderClass:
  def __init__(self,csv_file, coor_factor, batch_size):
    self.csv_file=csv_file
    self.coor_factor=coor_factor
    self.batch_size=batch_size

  def dataloader(self):
    data = pd.read_csv(self.csv_file)
    volume_shape = [256, 256, 256]
    images =()
    labels=()
    for image,label in zip(data['images'],data['GWlabels']):

      img = nib.load('./'+image)
      img = img.get_fdata()
      temp= np.zeros(volume_shape)
      temp[: img.shape[0], : img.shape[1], : img.shape[2]] = img
      temp = np.array(temp)
      image_data = (temp - temp.mean()) / temp.std()
      sub_temp = CubeDivider(torch.tensor(image_data),self.coor_factor)
      images = images+(sub_temp.divide_into_sub_cubes(),)

      lab = nib.load('./'+label)
      lab = lab.get_fdata()
      temp= np.zeros(volume_shape)
      temp[: lab.shape[0], : lab.shape[1], : lab.shape[2]] = lab
      temp = np.array(temp)
      sub_temp = CubeDivider(torch.tensor(temp),self.coor_factor)
      labels = labels+(sub_temp.divide_into_sub_cubes(),)

    images = torch.stack(images)
    labels = torch.stack(labels)
    images = images.reshape(-1,1,int(volume_shape[0]/self.coor_factor),int(volume_shape[1]/self.coor_factor),int(volume_shape[2]/self.coor_factor)).float()
    labels = labels.reshape(-1,1,int(volume_shape[0]/self.coor_factor),int(volume_shape[1]/self.coor_factor),int(volume_shape[2]/self.coor_factor))
    new_labels = ()
    for temp in labels:
      new_temp = ()
      for i in [0,1,2]:
        new_temp=new_temp+ (torch.mul(torch.tensor(np.asarray(temp == i, dtype=np.float64)),1),)
      new_temp = torch.stack(new_temp)
      new_labels = new_labels + (new_temp,)
    labels = torch.stack(new_labels)
    labels = labels.reshape(-1,3,int(volume_shape[0]/self.coor_factor),int(volume_shape[1]/self.coor_factor),int(volume_shape[2]/self.coor_factor))
    dataset = torch.utils.data.TensorDataset(images, labels)
    return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

# PyTorch training

The Dice score is a common metric for measuring the similarity between two sets or segmentations. It quantifies the overlap between the predicted and ground-truth labels as $\frac{2|X \cap Y|}{|X| + |Y|}$, with values from 0 to 1: a score of 1 means perfect overlap and 0 means no overlap. A small fudge factor added to the denominator avoids division by zero.
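A small worked example of the formula: for each label, count the voxels where both maps equal that label, double that count, and divide by the total number of voxels carrying the label in either map. The hypothetical `dice_per_label` function below follows the multi-label branch of `faster_dice`, using `enumerate` so each score lands at the position of its label in the `labels` list. Note that a label absent from both maps scores 0 under this formula, which pulls the mean Dice down.

```python
import torch

def dice_per_label(x, y, labels, fudge_factor=1e-8):
    """Dice = 2|X ∩ Y| / (|X| + |Y|) for each label, in the order given."""
    scores = torch.zeros(len(labels))
    for pos, label in enumerate(labels):
        x_l, y_l = x == label, y == label
        scores[pos] = 2 * (x_l & y_l).sum() / (x_l.sum() + y_l.sum() + fudge_factor)
    return scores

pred = torch.tensor([0, 1, 1, 2])
truth = torch.tensor([0, 1, 2, 2])
# label 0: 2*1/(1+1); label 1: 2*1/(2+1); label 2: 2*1/(1+2)
print(dice_per_label(pred, truth, labels=[0, 1, 2]))  # ≈ [1.0, 0.667, 0.667]
```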

In [None]:
import numpy as np
import torch


def faster_dice(x, y, labels, fudge_factor=1e-8):
    """Faster PyTorch implementation of Dice scores.
    :param x: input label map as torch.Tensor
    :param y: input label map as torch.Tensor of the same size as x
    :param labels: list of labels to evaluate on
    :param fudge_factor: an epsilon value to avoid division by zero
    :return: pytorch Tensor with Dice scores in the same order as labels.
    """

    assert x.shape == y.shape, "both inputs should have same size, had {} and {}".format(
        x.shape, y.shape
    )

    if len(labels) > 1:

        dice_score = torch.zeros(len(labels))
        # Index by position so labels need not be 0, 1, ..., len(labels) - 1
        for idx, label in enumerate(labels):
            x_label = x == label
            y_label = y == label
            xy_label = (x_label & y_label).sum()
            dice_score[idx] = (
                2 * xy_label / (x_label.sum() + y_label.sum() + fudge_factor)
            )

    else:
        dice_score = dice(x == labels[0], y == labels[0], fudge_factor=fudge_factor)

    return dice_score


def dice(x, y, fudge_factor=1e-8):
    """Dice score for two binary (0/1 or boolean) tensors"""
    return 2 * torch.sum(x * y) / (torch.sum(x) + torch.sum(y) + fudge_factor)

* The **trainer** class trains the neural network model for image segmentation.
* Its constructor takes the number of input channels, the number of output classes, the training and validation data loaders, the sub-volume shape, the number of epochs, the path to a model checkpoint file, and the learning rate.
* The constructor initializes the model, the loss criterion (**CrossEntropyLoss**), the optimizer (**RMSprop**), and the other class variables.
* The **train** method first tries to load the checkpoint and then trains the model for the given number of epochs. Within each epoch, it iterates over the training sub-volumes (skipping those that contain only background), computes the loss and Dice scores, backpropagates, and updates the model's parameters.
* After each training pass, it evaluates the model on the validation data, computes the loss and Dice scores, and prints the training and validation metrics for the epoch.
* The **faster_dice** function is used to calculate the Dice scores.
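The Dice scores from `faster_dice` are computed on `argmax` predictions, which are not differentiable, so they can serve only as a metric: any Dice term derived from them (for example via `.item()`) contributes no gradient. If you want Dice to drive optimization, a common alternative is a soft Dice loss computed on softmax probabilities. The sketch below is that swapped-in technique, not part of the original pipeline; `soft_dice_loss` is a hypothetical name, and it assumes logits and one-hot targets of shape `(N, C, D, H, W)`.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(logits, target_one_hot, eps=1e-6):
    """Return 1 - mean soft Dice over classes (differentiable w.r.t. logits)."""
    probs = F.softmax(logits, dim=1)   # class probabilities per voxel
    dims = (0, 2, 3, 4)                # sum over batch and spatial axes, keep classes
    intersection = (probs * target_one_hot).sum(dims)
    denom = probs.sum(dims) + target_one_hot.sum(dims)
    dice = (2 * intersection + eps) / (denom + eps)  # per-class soft Dice
    return 1 - dice.mean()

# Usage sketch: combine with cross-entropy, e.g.
# loss = criterion(outputs, labels) + soft_dice_loss(outputs, labels)
```

The `eps` terms keep the ratio defined for classes absent from both prediction and target.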

In [None]:
from torch.nn import functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class trainer:
    def __init__(self, n_channels, n_classes, trainloader, valloader, subvol_shape, epoches, modelpth, lrate=0.0007):
        self.n_channels = n_channels  # Number of input channels
        self.n_classes = n_classes  # Number of output classes
        self.model = enMesh_checkpoint(self.n_channels, self.n_classes).to(device, dtype=torch.float32)
        self.criterion = nn.CrossEntropyLoss()  # accepts one-hot (class-probability) targets
        self.lrate = lrate
        self.trainloader = trainloader
        self.valloader = valloader
        self.subvol_shape = subvol_shape
        self.epoches = epoches
        self.modelpth = modelpth
        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=self.lrate)

    def train(self, num_epoches):
        # Resume from the checkpoint of a previous training cycle, if one is given
        try:
            self.model.load_state_dict(torch.load(self.modelpth, map_location=device))
        except Exception:
            print('No valid pretained model.pth file mentioned')

        for epoch in range(num_epoches):
            self.model.train()
            train_loss, train_dice, n_train = 0.0, 0.0, 0
            for images, labels in self.trainloader:
                target = torch.argmax(torch.squeeze(labels), 0)  # class index per voxel (assumes batch_size=1)
                # Skip sub-volumes that contain only background
                if not ((target == 1).any() or (target == 2).any()):
                    continue
                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.float32)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                # Dice from argmax is not differentiable, so it is tracked as a metric only
                dice_scores = faster_dice(torch.argmax(torch.squeeze(outputs), 0), torch.argmax(torch.squeeze(labels), 0), labels=[0, 1, 2])
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()
                train_dice += dice_scores.mean().item()  # accumulate over the whole epoch
                n_train += 1

            self.model.eval()
            val_loss, val_dice = 0.0, 0.0
            with torch.no_grad():
                for images, labels in self.valloader:
                    images = images.to(device, dtype=torch.float32)
                    labels = labels.to(device, dtype=torch.float32)
                    outputs = self.model(images)
                    val_loss += self.criterion(outputs, labels).item()
                    dice_scores = faster_dice(torch.argmax(torch.squeeze(outputs), 0), torch.argmax(torch.squeeze(labels), 0), labels=[0, 1, 2])
                    val_dice += dice_scores.mean().item()

            # Average over the training batches actually used and over all validation batches
            train_loss /= max(n_train, 1)
            train_dice /= max(n_train, 1)
            val_loss /= len(self.valloader)
            val_dice /= len(self.valloader)

            print(f"Epoch {epoch+1} - Train Loss: {train_loss:.4f} - Train Dice: {train_dice:.4f} - Val Loss: {val_loss:.4f} - Val Dice: {val_dice:.4f}")

# Training



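**Cumulative learning**: the model is trained in PyTorch in four successive cycles, each of which divides the volumes into sub-volumes of a different size (256³, 32³, 64³, and 128³). Cycles 2 to 4 resume from the `meshnet.pth` checkpoint saved by the previous cycle.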
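The arithmetic behind the four cycles: `DataLoaderClass` splits each padded 256³ volume into `coor_factor`³ sub-volumes of side `256 // coor_factor`. The sketch below tabulates this for the `coor_factor` values used in the cycles (1, 8, 4, 2) and the 10 training scans; the helper name `cycle_table` is hypothetical. The trainer skips background-only sub-volumes, so the number actually used for gradient updates can be smaller.

```python
def cycle_table(coor_factors, volume_side=256, n_volumes=10):
    """Rows of (coor_factor, sub-volume side, sub-volumes per scan, total sub-volumes)."""
    rows = []
    for factor in coor_factors:
        per_scan = factor ** 3
        rows.append((factor, volume_side // factor, per_scan, per_scan * n_volumes))
    return rows

for row in cycle_table([1, 8, 4, 2]):
    print("coor_factor=%d: side %d, %d per scan, %d in total" % row)
# coor_factor=1: side 256, 1 per scan, 10 in total
# coor_factor=8: side 32, 512 per scan, 5120 in total
# coor_factor=4: side 64, 64 per scan, 640 in total
# coor_factor=2: side 128, 8 per scan, 80 in total
```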

## Training Cycle 1 : subvolume (256,256,256)

In [None]:
traindata = DataLoaderClass('./data/dataset_train.csv',1,1).dataloader()
valdata = DataLoaderClass('./data/dataset_valid.csv',1,1).dataloader()
meshnet = trainer(1,3,traindata, valdata, [256,256,256], 20,'',0.0007)
meshnet.train(20)
torch.save(meshnet.model.state_dict(), 'meshnet.pth')

No valid pretained model.pth file mentioned
Epoch 1 - Train Loss: 0.9851 - Train Dice: 0.0381 - Val Loss: 0.7933 - Val Dice: 0.3471
Epoch 2 - Train Loss: 0.6630 - Train Dice: 0.0321 - Val Loss: 0.6158 - Val Dice: 0.3251
Epoch 3 - Train Loss: 0.5784 - Train Dice: 0.0321 - Val Loss: 0.5654 - Val Dice: 0.3248
Epoch 4 - Train Loss: 0.5324 - Train Dice: 0.0320 - Val Loss: 0.5203 - Val Dice: 0.3248
Epoch 5 - Train Loss: 0.5012 - Train Dice: 0.0321 - Val Loss: 0.4618 - Val Dice: 0.3248
Epoch 6 - Train Loss: 0.4746 - Train Dice: 0.0319 - Val Loss: 0.4433 - Val Dice: 0.3248
Epoch 7 - Train Loss: 0.4529 - Train Dice: 0.0322 - Val Loss: 0.4214 - Val Dice: 0.3248
Epoch 8 - Train Loss: 0.4355 - Train Dice: 0.0320 - Val Loss: 0.4150 - Val Dice: 0.3248
Epoch 9 - Train Loss: 0.4193 - Train Dice: 0.0322 - Val Loss: 0.3806 - Val Dice: 0.3248
Epoch 10 - Train Loss: 0.4045 - Train Dice: 0.0322 - Val Loss: 0.3737 - Val Dice: 0.3248
Epoch 11 - Train Loss: 0.3914 - Train Dice: 0.0322 - Val Loss: 0.3493 - Val

In [None]:
import gc
# Drop references to large objects from the previous cycle so their memory can be reclaimed
for name in ['traindata', 'valdata', 'meshnet', 'model', 'images', 'labels', 'img', 'temp',
             'array', 'array1', 'i', 'prediciton', 'predicted', 'pred_peek', 'volume_shape',
             'num_cubes', 'nifi_image', 'criterion', 'loaders', 'divider', 'sub_cubes',
             'logdir', 'optimizer', 'scheduler', 'runner', 'tensor', 'reconstructed_tensor']:
    globals().pop(name, None)
gc.collect()
torch.cuda.empty_cache()

0

## Training Cycle 2 : subvolume (32,32,32)

In [None]:
traindata = DataLoaderClass('./data/dataset_train.csv',8,1).dataloader()
valdata = DataLoaderClass('./data/dataset_valid.csv',8,1).dataloader()
meshnet = trainer(1,3,traindata, valdata, [32,32,32], 20,'meshnet.pth',0.0007)
meshnet.train(20)
torch.save(meshnet.model.state_dict(), 'meshnet.pth')

In [None]:
import gc
# Drop references to large objects from the previous cycle so their memory can be reclaimed
for name in ['traindata', 'valdata', 'meshnet', 'model', 'images', 'labels', 'img', 'temp',
             'array', 'array1', 'i', 'prediciton', 'predicted', 'pred_peek', 'volume_shape',
             'num_cubes', 'nifi_image', 'criterion', 'loaders', 'divider', 'sub_cubes',
             'logdir', 'optimizer', 'scheduler', 'runner', 'tensor', 'reconstructed_tensor']:
    globals().pop(name, None)
gc.collect()
torch.cuda.empty_cache()

## Training Cycle 3 : subvolume (64,64,64)

In [None]:
traindata = DataLoaderClass('./data/dataset_train.csv',4,1).dataloader()
valdata = DataLoaderClass('./data/dataset_valid.csv',4,1).dataloader()
meshnet = trainer(1,3,traindata, valdata, [64,64,64], 20,'meshnet.pth',0.0007)
meshnet.train(20)
torch.save(meshnet.model.state_dict(), 'meshnet.pth')

In [None]:
import gc
# Drop references to large objects from the previous cycle so their memory can be reclaimed
for name in ['traindata', 'valdata', 'meshnet', 'model', 'images', 'labels', 'img', 'temp',
             'array', 'array1', 'i', 'prediciton', 'predicted', 'pred_peek', 'volume_shape',
             'num_cubes', 'nifi_image', 'criterion', 'loaders', 'divider', 'sub_cubes',
             'logdir', 'optimizer', 'scheduler', 'runner', 'tensor', 'reconstructed_tensor']:
    globals().pop(name, None)
gc.collect()
torch.cuda.empty_cache()

## Training Cycle 4 : subvolume (128,128,128)

In [None]:
traindata = DataLoaderClass('./data/dataset_train.csv',2,1).dataloader()
valdata = DataLoaderClass('./data/dataset_valid.csv',2,1).dataloader()
meshnet = trainer(1,3,traindata, valdata, [128,128,128], 20,'meshnet.pth',0.0007)
meshnet.train(20)
torch.save(meshnet.model.state_dict(), 'meshnet.pth')

In [None]:
import gc
# Drop references to large objects from the previous cycle so their memory can be reclaimed
for name in ['traindata', 'valdata', 'meshnet', 'model', 'images', 'labels', 'img', 'temp',
             'array', 'array1', 'i', 'prediciton', 'predicted', 'pred_peek', 'volume_shape',
             'num_cubes', 'nifi_image', 'criterion', 'loaders', 'divider', 'sub_cubes',
             'logdir', 'optimizer', 'scheduler', 'runner', 'tensor', 'reconstructed_tensor']:
    globals().pop(name, None)
gc.collect()
torch.cuda.empty_cache()

# Model evaluation on the inference dataset

In [None]:
from torch.nn import functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class evaluation:
  def __init__(self, modelpath, inferloader):
    self.inferloader = inferloader
    self.modelpath = modelpath
    self.model = enMesh_checkpoint(1, 3).to(device, dtype=torch.float32)
    self.criterion = nn.CrossEntropyLoss()

  def eval(self):
    # Load the trained weights; report and stop if the checkpoint cannot be loaded
    try:
      self.model.load_state_dict(torch.load(self.modelpath, map_location=device))
    except Exception as e:
      print('No valid pretained model.pth file mentioned', e)
      return
    self.model.eval()
    infer_loss = 0.0
    infer_dice = 0.0
    with torch.no_grad():
      for images, labels in self.inferloader:
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.float32)
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        infer_loss += loss.item()
        dice_scores = faster_dice(torch.argmax(torch.squeeze(outputs),0), torch.argmax(torch.squeeze(labels),0), labels=[0, 1, 2])
        infer_dice += dice_scores.mean().item()
    infer_loss /= len(self.inferloader)
    infer_dice /= len(self.inferloader)
    print('Loss :',infer_loss,'  Dice :',infer_dice)


In [None]:
inferdata = DataLoaderClass('./data/dataset_infer.csv',1,1).dataloader()
modeval=evaluation('meshnet.pth',inferdata)
modeval.eval()

Loss : 0.2616010407606761   Dice : 0.3222907781600952


# Plots

In [None]:
inferdata = DataLoaderClass('./data/dataset_infer.csv',1,1).dataloader()

Ground-truth label

In [None]:
ground_truth = inferdata.dataset.tensors[1][0].reshape(-1,3,256,256,256)  # one-hot labels of the first scan
gt_classes = torch.argmax(torch.squeeze(ground_truth),0)  # class index per voxel
array1 = gt_classes.reshape(256,256,256).numpy().astype(np.uint16)
nifi_image = nib.Nifti1Image(array1, affine=np.eye(4))  # Use identity affine matrix for simplicity
pred_peek= peek_class_new(nifi_image,1)
pred_peek.plots()

interactive(children=(IntSlider(value=127, description='X', max=255), IntSlider(value=127, description='Y', ma…

IntSlider(value=127, description='X', max=255)

IntSlider(value=127, description='Y', max=255)

IntSlider(value=127, description='Z', max=255)

Predicted label

In [None]:
model = enMesh_checkpoint(n_channels=1, n_classes=3)
model.load_state_dict(
    torch.load("meshnet.pth", map_location="cpu")  # this model runs on the CPU
)
model.eval()

enMesh_checkpoint(
  (model): Sequential(
    (0): Sequential(
      (conv): Conv3d(1, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout): Dropout3d(p=0, inplace=False)
    )
    (1): Sequential(
      (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))
      (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout): Dropout3d(p=0, inplace=False)
    )
    (2): Sequential(
      (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))
      (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dropout): Dropout3d(p=0, inplace=False)
    )
    (3): Sequential(
      (conv): Conv3d(5, 5, kernel

In [None]:
prediction = model(inferdata.dataset.tensors[0][0].reshape(-1,1,256,256,256))  # logits, shape (1, 3, 256, 256, 256)
predicted = torch.argmax(torch.squeeze(prediction),0)  # class index per voxel
array = predicted.reshape(256,256,256).numpy().astype(np.uint16)
nifi_image = nib.Nifti1Image(array, affine=np.eye(4))  # Use identity affine matrix for simplicity
pred_peek= peek_class_new(nifi_image,1)
pred_peek.plots()

# **Final notes**

This tutorial provides a simple example of how to train the MeshNet model. Note, however, that the actual Brainchop models used in the tool achieve high accuracy, and reaching that level requires thousands of MRI scans during training. The current capacity of Google Colab is insufficient for a dataset of that size, so a compute cluster is required for that training.