688 lines (687 with data), 23.0 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"_uuid": "5e87694dcaf86e48e7423752ee27ffa3a7cc4525"
},
"source": [
"## Preparing libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split,KFold,StratifiedKFold\n",
"import torch\n",
"os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
"os.environ['CUDA_VISIBLE_DEVICES']='4,5,6,7' \n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"from sklearn import preprocessing\n",
"\n",
"from torch.utils.data import TensorDataset, DataLoader,Dataset\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"import torch.backends.cudnn as cudnn\n",
"import time \n",
"import tqdm\n",
"import random\n",
"from PIL import Image\n",
"train_on_gpu = True\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR\n",
"from matplotlib.image import imread\n",
"from torch.utils import data\n",
"import seaborn as sns\n",
"\n",
"from itertools import cycle\n",
"\n",
"from sklearn import svm, datasets\n",
"from sklearn.metrics import roc_curve, auc,roc_auc_score,classification_report, confusion_matrix\n",
"from sklearn.preprocessing import label_binarize, StandardScaler\n",
"from sklearn.multiclass import OneVsRestClassifier, unique_labels\n",
"from scipy import interp\n",
"from utils import print_confusion_matrix\n",
"\n",
"from sklearn.utils import shuffle\n",
"import albumentations\n",
"from albumentations import torch as AT\n",
"import scipy.special\n",
"\n",
"from pytorchcv.model_provider import get_model as ptcv_get_model\n",
"\n",
"cudnn.benchmark = True\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"from torch.utils.data import WeightedRandomSampler\n",
"from processing_pytorch import generate_dataset_grade, CancerDataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## define constant"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"INPUT_SHAPE = 224 ##for attention network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
"_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
},
"outputs": [],
"source": [
"SEED = 323\n",
"def seed_everything(seed=SEED):\n",
" random.seed(seed)\n",
" os.environ['PYHTONHASHSEED'] = str(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed(seed)\n",
" torch.backends.cudnn.deterministic = True\n",
"seed_everything(SEED)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"main_path = '.'"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "5f92a7b85b373990c8aff56efd61fc09193ba337"
},
"source": [
"## Preparing data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train, train_paths = generate_dataset_grade(main_path,os.path.join(r'.\\\\','msi_data_train_grade.csv'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_test, y_test, test_paths = generate_dataset_grade(main_path,os.path.join(r'.\\\\','msi_data_test_grade.csv'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_valid, y_valid, valid_paths = generate_dataset_grade(main_path,os.path.join(r'.\\\\','msi_data_valid_grade.csv'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### balance the training set by downsampling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index_h =list(np.where(y_train == 'non-dysplasia')[0])\n",
"index_lg =list(np.where(y_train == 'low grade')[0])\n",
"index_hg =list(np.where(y_train == 'high grade')[0])\n",
"max_len = len(index_hg)\n",
"len_h = len(index_h)\n",
"len_lg = len(index_lg)\n",
"\n",
"balanced_X_train = np.concatenate((np.array([x for i,x in enumerate(X_train) if i in index_h])[np.random.randint(0,len_h,max_len)], np.array([x for i,x in enumerate(X_train) if i in index_lg])[np.random.randint(0,len_lg,max_len)],np.array([x for i,x in enumerate(X_train) if i in index_hg])))\n",
"balanced_y_train = np.array(['non-dysplasia']*max_len + ['low grade']*max_len + ['high grade']*max_len)\n",
"balanced_X_train,balanced_y_train = shuffle(balanced_X_train,balanced_y_train,random_state=SEED)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lb = preprocessing.LabelBinarizer()\n",
"bin_balanced_y_train = lb.fit_transform(balanced_y_train)\n",
"bin_y_test = lb.transform(y_test)\n",
"bin_y_val = lb.transform(y_valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"balanced_X_train=balanced_X_train.astype(np.uint8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"balanced_y_train=balanced_y_train.astype(np.uint8)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "f2c6e4883abfb0869f277e512f343c54c70779d6"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "74a6d1c81be396c2993e668f0375c16a095cd793"
},
"outputs": [],
"source": [
"data_transforms = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.RandomRotate90(p=0.5),\n",
" albumentations.Transpose(p=0.5),\n",
" albumentations.Flip(p=0.5),\n",
" albumentations.OneOf([\n",
" albumentations.CLAHE(clip_limit=2), albumentations.IAASharpen(), albumentations.IAAEmboss(), \n",
" albumentations.RandomBrightness(), albumentations.RandomContrast(),\n",
" albumentations.JpegCompression(), albumentations.Blur(), albumentations.GaussNoise()], p=0.5), \n",
" albumentations.HueSaturationValue(p=0.5), \n",
" albumentations.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.15, rotate_limit=45, p=0.5),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])\n",
"\n",
"data_transforms_test = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])\n",
"\n",
"data_transforms_tta0 = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.RandomRotate90(p=0.5),\n",
" albumentations.Transpose(p=0.5),\n",
" albumentations.Flip(p=0.5),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])\n",
"\n",
"data_transforms_tta1 = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.RandomRotate90(p=1),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])\n",
"\n",
"data_transforms_tta2 = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.Transpose(p=1),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])\n",
"\n",
"data_transforms_tta3 = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.Flip(p=1),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])\n",
"\n",
"dataset = CancerDataset(balanced_X_train, balanced_y_train, transform=data_transforms)\n",
"test_set = CancerDataset(X_test, y_test, transform=data_transforms_test)\n",
"val_set = CancerDataset(X_val, y_val, transform=data_transforms_test)\n",
"batch_size = 16\n",
"num_workers = 0\n",
"# prepare data loaders (combine dataset and sampler)\n",
"train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=None, num_workers=num_workers)\n",
"test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=None, num_workers=num_workers)\n",
"valid_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, sampler=None, num_workers=num_workers)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "b137050e2ac7ad1229e7e8ace3bc3a3175340859"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "1b8d3de18c8a43415de2c93494c436423d67ebc0"
},
"outputs": [],
"source": [
"model_conv = ptcv_get_model(\"cbam_resnet50\", pretrained=True)\n",
"model_conv.output = nn.Linear(in_features=2048, out_features=3, bias=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "5d6cefbb5bca7ee5c01c77ddf1e5d6199837582c"
},
"outputs": [],
"source": [
"model_conv.cuda()\n",
"criterion = nn.CrossEntropyLoss() #weight=class_weights.to(device)\n",
"\n",
"optimizer = optim.Adam(model_conv.parameters(), lr=0.0001)\n",
"\n",
"scheduler = StepLR(optimizer, 5, gamma=0.2)\n",
"scheduler.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if torch.cuda.device_count() > 1:\n",
" print(torch.cuda.device_count() )\n",
" model_conv = nn.DataParallel(model_conv,device_ids=[0,1,2,3])\n",
"model_conv.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def multi_acc(y_pred, y_test):\n",
" y_pred_softmax = torch.log_softmax(y_pred, dim = 1)\n",
" _, y_pred_tags = torch.max(y_pred_softmax, dim = 1) \n",
" \n",
" correct_pred = (y_pred_tags == y_test).float()\n",
" acc = correct_pred.sum() / len(correct_pred)\n",
" \n",
" acc = torch.round(acc) \n",
" \n",
" return acc"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "d384f23b71a56031ae141e5fa4e8e216c862749c",
"scrolled": true
},
"outputs": [],
"source": [
"EPOCHS=20\n",
"train_acc_max = 0\n",
"print(\"Begin training.\")\n",
"for e in tqdm.tqdm(range(1, EPOCHS+1)):\n",
" \n",
" # TRAINING\n",
" train_epoch_loss = 0\n",
" train_epoch_acc = 0\n",
" model_conv.train()\n",
" for tr_batch_i ,(X_train_batch, y_train_batch) in enumerate(train_loader):\n",
" X_train_batch, y_train_batch = X_train_batch.to(device), y_train_batch.to(device)\n",
" optimizer.zero_grad()\n",
" \n",
" y_train_pred = model_conv(X_train_batch)\n",
" \n",
" y_train_batch = y_train_batch.argmax(1)\n",
" \n",
" train_loss = criterion(y_train_pred, y_train_batch)\n",
" train_acc = multi_acc(y_train_pred, y_train_batch)\n",
" \n",
" train_loss.backward()\n",
" optimizer.step()\n",
" \n",
" train_epoch_loss += train_loss.item()\n",
" train_epoch_acc += train_acc.item()\n",
" \n",
" \n",
" # TESTING \n",
" if (tr_batch_i+1)%600 == 0: \n",
" with torch.no_grad():\n",
"\n",
" test_epoch_loss = 0\n",
" test_epoch_acc = 0\n",
"\n",
" model_conv.eval()\n",
" for X_test_batch, y_test_batch in test_loader:\n",
" X_test_batch, y_test_batch = X_test_batch.to(device), y_test_batch.to(device)\n",
"\n",
" y_test_pred = model_conv(X_test_batch)\n",
"\n",
" y_test_batch = y_test_batch.argmax(1)\n",
"\n",
" test_loss = criterion(y_test_pred, y_test_batch)\n",
" test_acc = multi_acc(y_test_pred, y_test_batch)\n",
"\n",
" test_epoch_loss += test_loss.item()\n",
" test_epoch_acc += test_acc.item()\n",
"\n",
" print(f'Epoch {e+0:03}: | Train Loss: {train_epoch_loss/len(train_loader):.5f} | Test Loss: {test_epoch_loss/len(test_loader):.5f} | Train Acc: {train_epoch_acc/len(train_loader):.3f}| Test Acc: {test_epoch_acc/len(test_loader):.3f}')\n",
"\n",
" test_acc = test_epoch_acc/len(test_loader)\n",
" train_acc = train_epoch_acc/len(train_loader)\n",
" if train_acc > train_acc_max:\n",
" print('Training accuracy increased ({:.6f} --> {:.6f}). Saving model ...'.format(\n",
" train_acc_max,\n",
" train_acc))\n",
" torch.save(model_conv.state_dict(), r\".\\models_attention-grade-classif//model_epoch_{}_test_{:.4f}.pt\".format(e ,(test_acc*100)))\n",
" train_acc_max = train_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "964ed06f462a6c6e6503f5f4259f705e650066a4"
},
"source": [
"## Generate DL features tables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "71258d16737b84ad8ef4b6f69960224f429d73e5"
},
"outputs": [],
"source": [
"model_conv.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "d427380ffd08093b5bc2a751cba7ca08c75332d6"
},
"outputs": [],
"source": [
"saved_dict = torch.load(r'.') #load weights\n",
"model_conv.load_state_dict(saved_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"new_classifier = nn.Sequential(*list(model_conv.children())[-1].features).cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_transforms = albumentations.Compose([\n",
" albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
" albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
" AT.ToTensor()\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_features = {}\n",
"for i,temp_X in tqdm.tqdm(enumerate(X_train)):\n",
" tensor_X=data_transforms(image=temp_X)\n",
" tensor_X=tensor_X[\"image\"]\n",
" tensor_X.unsqueeze_(0)\n",
" output=torch.flatten(new_classifier(tensor_X)).detach().numpy()\n",
" train_features[train_paths[i]]=output.flatten()\n",
"train_features=pd.DataFrame.from_dict(train_features)\n",
"train_features=train_features.T\n",
"train_features.to_csv(r\".\\train_features_grade_he.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_features = {}\n",
"for i,temp_X in tqdm.tqdm(enumerate(X_test)):\n",
" tensor_X=data_transforms(image=temp_X)\n",
" tensor_X=tensor_X[\"image\"]\n",
" tensor_X.unsqueeze_(0)\n",
" output=torch.flatten(new_classifier(tensor_X)).detach().numpy()\n",
" test_features[test_paths[i]]=output.flatten()\n",
"test_features=pd.DataFrame.from_dict(test_features)\n",
"test_features=test_features.T\n",
"test_features.to_csv(r\".\\test_features_grade_he.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valid_features = {}\n",
"for i,temp_X in tqdm.tqdm(enumerate(X_val)):\n",
" tensor_X=data_transforms(image=temp_X)\n",
" tensor_X=tensor_X[\"image\"]\n",
" tensor_X.unsqueeze_(0)\n",
" output=torch.flatten(new_classifier(tensor_X)).detach().numpy()\n",
" valid_features[val_paths[i]]=output.flatten()\n",
"valid_features=pd.DataFrame.from_dict(valid_features)\n",
"valid_features=valid_features.T\n",
"valid_features.to_csv(r\".\\valid_features_grade_he.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "5013176b39c0142ff6dfd2a3f013ea2ba2f2560c"
},
"source": [
"## TTA inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "591fa5e73bf4fedf11a15e50c5fcc788cbd4a0b8"
},
"outputs": [],
"source": [
"NUM_TTA = 10"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"_uuid": "9dacb064b3fbbf82a9ef49eb0a5eb1b33a74d06f"
},
"outputs": [],
"source": [
"sigmoid = lambda x: scipy.special.expit(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def def_tta(X_data,y_data):\n",
" for num_tta in range(NUM_TTA):\n",
" if num_tta==0:\n",
" test_set = CancerDataset(X_data, y_data, transform=data_transforms_test)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
" elif num_tta==1:\n",
" test_set = CancerDataset(X_data, y_data, transform=data_transforms_tta1)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
" elif num_tta==2:\n",
" test_set = CancerDataset(X_data, y_data, transform=data_transforms_tta2)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
" elif num_tta==3:\n",
" test_set = CancerDataset(X_data, y_data, transform=data_transforms_tta3)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
" elif num_tta<8:\n",
" test_set = CancerDataset(X_data, y_data, transform=data_transforms_tta0)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
" else:\n",
" test_set = CancerDataset(X_data, y_data, transform=data_transforms)\n",
" test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
"\n",
" preds = []\n",
" for batch_i, (data, target) in enumerate(test_loader):\n",
" data, target = data.cuda(), target.cuda()\n",
" output = model_conv(data).detach()\n",
" pr = output.cpu().numpy()\n",
" for i in pr:\n",
" preds.append(sigmoid(i)/NUM_TTA)\n",
" \n",
" array_pred = np.array([item for sublist in preds for item in sublist ]).reshape(-1,3)\n",
" \n",
" \n",
" if num_tta==0:\n",
" test_preds = pd.DataFrame({'imgs': test_set.image_files_list, 'preds 0': array_pred[:,0], 'preds 1': array_pred[:,1], 'preds 2': array_pred[:,2]})\n",
" test_preds['imgs'] = test_preds['imgs'].apply(lambda x: x.split('.')[0])\n",
" else:\n",
" test_preds['preds 0']=+array_pred[:,0]\n",
" test_preds['preds 1']=+array_pred[:,1]\n",
" test_preds['preds 2']=+array_pred[:,2]\n",
" \n",
" print(num_tta)\n",
" return(test_preds)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_preds = def_tta(X_test,y_test)\n",
"valid_preds = def_tta(X_val,y_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## save results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"df_test_preds = pd.DataFrame(test_preds)\n",
"df_test_preds.to_csv(r'.\\sub_10_tta_test.csv', index=False)\n",
"df_test_labels = pd.DataFrame(y_test)\n",
"df_test_labels.to_csv(r'.\\labels_test.csv', index=False)\n",
"flatten_test_preds = np.array(test_preds[['preds 0','preds 1','preds 2']]).ravel()\n",
"np.savetxt(r'.\\sub_10_tta_test_flatten.txt',flatten_test_preds)\n",
"flatten_y_preds = bin_y_test.ravel()\n",
"np.savetxt(r'.\\labels_test_flatten.txt',flatten_y_preds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_valid_preds = pd.DataFrame(valid_preds)\n",
"df_valid_preds.to_csv(r'.\\sub_10_tta_valid.csv', index=False)\n",
"df_valid_labels = pd.DataFrame(y_valid)\n",
"df_valid_labels.to_csv(r'.\\labels_valid.csv', index=False)\n",
"flatten_valid_preds = np.array(valid_preds[['preds 0','preds 1','preds 2']]).ravel()\n",
"np.savetxt(r'.\\sub_10_tta_valid_flatten.txt',flatten_valid_preds)\n",
"flatten_y_preds = np.array(bin_y_val).ravel()\n",
"np.savetxt(r'.\\labels_valid_flatten.txt',flatten_y_preds)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}