In [None]:
# Environment setup (Colab).
# The two commented-out lines below would install PyTorch/XLA for TPU training; they are
# disabled, and the Trainer later in this notebook is configured with gpus=1.
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
# Clones into the current working directory (normally /content on Colab). Later cells change
# directory into Google Drive, so make sure the MedicalZooPytorch path added to sys.path
# actually points at this clone (or at an intended copy on Drive) — TODO confirm.
!git clone https://github.com/black0017/MedicalZooPytorch.git
# NOTE(review): installs are unpinned; pin versions for reproducible re-runs.
!pip install pytorch_lightning
!pip install torchio
!pip install torchsummaryX
!pip install wandb

In [2]:
# Mount Google Drive at /content/drive; this cell and the next one give access to files stored in Drive.
# force_remount=True is meant to allow re-running this cell when Drive is already mounted — confirm
# against the google.colab.drive docs.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
cd drive/MyDrive

/content/drive/MyDrive


In [4]:
# Build two parallel lists of torchio Subjects, one per BraTS 2020 case folder:
#  - `subjects`: the four MRI sequences stacked as one 4-channel image under key 'data'
#    (channel order: flair, t1, t2, t1ce), plus the label map under 'seg'.
#  - `subjects_separate`: each sequence under its own key ('t1', 'flair', 't2', 't1ce'),
#    plus the label map under 'seg'.
import os
import tqdm
import torchio as tio

base_dir = './macai_datasets/brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'


def _case_file(case_id, suffix):
    """Return `<base_dir><case_id>/<case_id>_<suffix>.nii.gz` (e.g. suffix='flair')."""
    return os.path.join(base_dir, case_id, f'{case_id}_{suffix}.nii.gz')


# Only sub-directories are treated as cases; other entries in the folder are skipped.
# os.listdir order is arbitrary, so sort to make subject order (and therefore the
# downstream train/validation split) reproducible across runs and machines.
case_ids = sorted(
    name for name in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, name))
)

subjects = []
subjects_separate = []
for case_id in tqdm.tqdm(case_ids):
    subjects.append(tio.Subject(
        data=tio.ScalarImage(path=[_case_file(case_id, s) for s in ('flair', 't1', 't2', 't1ce')]),
        seg=tio.LabelMap(path=[_case_file(case_id, 'seg')]),
    ))
    subjects_separate.append(tio.Subject(
        t1=tio.ScalarImage(path=[_case_file(case_id, 't1')]),
        flair=tio.ScalarImage(path=[_case_file(case_id, 'flair')]),
        t2=tio.ScalarImage(path=[_case_file(case_id, 't2')]),
        t1ce=tio.ScalarImage(path=[_case_file(case_id, 't1ce')]),
        seg=tio.LabelMap(path=[_case_file(case_id, 'seg')]),
    ))

dataset = tio.SubjectsDataset(subjects)
dataset_separate = tio.SubjectsDataset(subjects_separate)

100%|██████████| 368/368 [00:01<00:00, 287.31it/s]


In [5]:
def _make_transform():
    """Build the shared torchio preprocessing pipeline.

    Crops/pads every volume to 240x240x160 so samples can be stacked into a single
    batch tensor, then one-hot encodes the label map into 5 channels (label values 0..4).
    """
    steps = [
        tio.CropOrPad((240, 240, 160)),
        tio.OneHot(num_classes=5),
    ]
    return tio.Compose(steps)


# Add new transforms (e.g. augmentation) to the training pipeline only, so that
# validation preprocessing stays deterministic.
training_transform = _make_transform()
validation_transform = _make_transform()

In [6]:
# Split subjects into training / validation sets.
# One seeded index permutation is applied to BOTH `subjects` and `subjects_separate`,
# so case i lands in the same split in both views and the split is reproducible.
# (Previously the "separate" split was drawn from `subjects` by mistake, and each
# split used its own unseeded shuffle.)
import torch

training_split_ratio = 0.9
SPLIT_SEED = 0

num_subjects = len(dataset)
assert len(subjects) == len(subjects_separate) == num_subjects, 'subject lists are out of sync'
num_training_subjects = int(training_split_ratio * num_subjects)
num_validation_subjects = num_subjects - num_training_subjects
num_split_subjects = num_training_subjects, num_validation_subjects

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(num_subjects, generator=split_generator).tolist()
training_indices = shuffled_indices[:num_training_subjects]
validation_indices = shuffled_indices[num_training_subjects:]

training_subjects = torch.utils.data.Subset(subjects, training_indices)
validation_subjects = torch.utils.data.Subset(subjects, validation_indices)
training_subjects_separate = torch.utils.data.Subset(subjects_separate, training_indices)
validation_subjects_separate = torch.utils.data.Subset(subjects_separate, validation_indices)

training_set = tio.SubjectsDataset(training_subjects, training_transform)
validation_set = tio.SubjectsDataset(validation_subjects, validation_transform)
training_set_separate = tio.SubjectsDataset(training_subjects_separate, training_transform)
validation_set_separate = tio.SubjectsDataset(validation_subjects_separate, validation_transform)
print('Training set:', len(training_set), 'subjects')
print('Validation set:', len(validation_set), 'subjects')
print('Training set (separate):', len(training_set_separate), 'subjects')
print('Validation set (separate):', len(validation_set_separate), 'subjects')

Training set: 331 subjects
Validation set: 37 subjects
Training set: 331 subjects
Validation set: 37 subjects


In [18]:
# Collate function
def col_fn(batch):
    """Collate a list of torchio Subjects into a dict of float tensors.

    Each sample's ``sample[key]['data']`` tensor is cast to float and stacked along a
    new leading batch dimension, for key in ('data', 'seg').

    Returns:
        dict with keys 'data' and 'seg' (in that order), each a stacked float tensor.
    """
    return {
        key: torch.stack([sample[key]['data'].float() for sample in batch])
        for key in ('data', 'seg')
    }

In [19]:
import sys
# MedicalZooPytorch provides lib.medzoo / lib.losses3D. The setup cell clones it into the
# starting directory (/content on Colab), but an earlier cell `cd`s into Drive, so the
# relative entry below resolves under /content/drive/MyDrive. Keep that entry first
# (preserving its lookup priority) and add the clone location as a fallback so a fresh
# Restart & Run All still finds the package.
sys.path.append('./MedicalZooPytorch')
sys.path.append('/content/MedicalZooPytorch')
import torch
from lib.medzoo.Unet3D import UNet3D
from lib.losses3D.basic import compute_per_channel_dice, expand_as_one_hot
import numpy as np
import pytorch_lightning as pl
import os
from torch.utils.data import Dataset, DataLoader, random_split
from pytorch_lightning.loggers import WandbLogger
import nibabel as nb
from skimage import transform
import matplotlib.pyplot as plt

class TumourSegmentation(pl.LightningModule):
  """LightningModule wrapping MedicalZoo's UNet3D for BraTS tumour segmentation.

  Args:
    learning_rate: Adam learning rate.
    in_channels: number of input image channels (4 when flair/t1/t2/t1ce are stacked).
    classes: label values to segment. Each becomes one output channel, in this order,
      and is used directly as a channel index into the one-hot ``batch['seg']`` tensor,
      so every value must be smaller than the one-hot ``num_classes`` used upstream.

  Batches are expected to be dicts with 'data' (model input) and 'seg' (one-hot labels,
  channel dim 1), as produced by ``col_fn``.

  NOTE(review): ``y_hat`` is passed to ``compute_per_channel_dice`` without a
  sigmoid/softmax; confirm whether UNet3D already normalises its output.
  """

  # Metric-name suffix for each known label value; labels not listed are not logged.
  _CLASS_NAMES = {1: 'core', 2: 'edema', 4: 'enhancing'}

  def __init__(self, learning_rate, in_channels=4, classes=(1, 2, 4)):
    super().__init__()
    self.model = UNet3D(in_channels=in_channels, n_classes=len(classes), base_n_filter=8)
    self.learning_rate = learning_rate
    self.in_channels = in_channels
    self.classes = classes

  def forward(self, x):
    # Call the module rather than .forward() so nn.Module hooks run.
    return self.model(x)

  def _select_targets(self, seg):
    """Keep only the one-hot channels listed in self.classes, in that order."""
    return seg[:, list(self.classes)]

  def _shared_step(self, batch, prefix):
    """Compute the summed negative per-class Dice and log each class under `prefix`."""
    x = batch['data']
    y = self._select_targets(batch['seg'])
    y_hat = self(x)

    # One negative Dice value per output channel / class.
    per_class_loss = -1 * compute_per_channel_dice(y_hat, y)
    for i, label in enumerate(self.classes):
      name = self._CLASS_NAMES.get(label)
      if name is not None:
        self.log(f'{prefix}_loss_{name}', per_class_loss[i], prog_bar=True, logger=True)

    # Total loss is the SUM over class channels (not the mean).
    return torch.sum(per_class_loss)

  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, 'train')

  def validation_step(self, batch, batch_idx):
    # 'test' prefix kept for continuity with existing logged runs, although these
    # metrics are computed on the validation loader.
    return self._shared_step(batch, 'test')

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


In [20]:
# Seed Python / NumPy / torch RNGs before building the model so that weight
# initialisation (and later use of torch's global RNG) is reproducible.
SEED = 42
pl.seed_everything(SEED)
model = TumourSegmentation(learning_rate=5e-5)

In [21]:
wandb_logger = WandbLogger(project='macai', name='torchiotest', offline=False, reinit=True)

# Training
trainer = pl.Trainer(
    accumulate_grad_batches=1,
    gpus=1,
    max_epochs=10,
    precision=16,                # native mixed precision
    check_val_every_n_epoch=1,
    logger=wandb_logger,
    log_every_n_steps=10,
    val_check_interval=50,       # int: run validation every 50 training batches
    progress_bar_refresh_rate=1,
)

# NOTE: despite their names, train_set / val_set are DataLoaders.
# Pass col_fn directly: the former `lambda x: col_fn(x)` wrapper added nothing and a
# lambda cannot be pickled, which breaks worker processes started with 'spawn'.
train_set = torch.utils.data.DataLoader(
    training_set, batch_size=2, num_workers=2, shuffle=True, collate_fn=col_fn)
val_set = torch.utils.data.DataLoader(
    validation_set, batch_size=2, num_workers=2, collate_fn=col_fn)
trainer.fit(model, train_set, val_set)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.

  | Name  | Type   | Params
---------------------------------
0 | model | UNet3D | 1.8 M 
---------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.128     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

RuntimeError: ignored