--- a
+++ b/pytorch_HE_grade_classification.ipynb
@@ -0,0 +1,687 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "_uuid": "5e87694dcaf86e48e7423752ee27ffa3a7cc4525"
+   },
+   "source": [
+    "## Preparing libraries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import os\n",
+    "import cv2\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.model_selection import train_test_split,KFold,StratifiedKFold\n",
+    "import torch\n",
+    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='4,5,6,7'                        \n",
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+    "from sklearn import preprocessing\n",
+    "\n",
+    "from torch.utils.data import TensorDataset, DataLoader,Dataset\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torchvision\n",
+    "import torchvision.transforms as transforms\n",
+    "import torch.optim as optim\n",
+    "from torch.optim import lr_scheduler\n",
+    "import torch.backends.cudnn as cudnn\n",
+    "import time \n",
+    "import tqdm\n",
+    "import random\n",
+    "from PIL import Image\n",
+    "train_on_gpu = True\n",
+    "from torch.utils.data.sampler import SubsetRandomSampler\n",
+    "from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR\n",
+    "from matplotlib.image import imread\n",
+    "from torch.utils import data\n",
+    "import seaborn as sns\n",
+    "\n",
+    "from itertools import cycle\n",
+    "\n",
+    "from sklearn import svm, datasets\n",
+    "from sklearn.metrics import roc_curve, auc,roc_auc_score,classification_report, confusion_matrix\n",
+    "from sklearn.preprocessing import label_binarize, StandardScaler\n",
+    "from sklearn.multiclass import OneVsRestClassifier\n",
+    "from sklearn.utils.multiclass import unique_labels\n",
+    "from scipy import interp\n",
+    "from utils import print_confusion_matrix\n",
+    "\n",
+    "from sklearn.utils import shuffle\n",
+    "import albumentations\n",
+    "from albumentations import torch as AT\n",
+    "import scipy.special\n",
+    "\n",
+    "from pytorchcv.model_provider import get_model as ptcv_get_model\n",
+    "\n",
+    "cudnn.benchmark = True\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")\n",
+    "\n",
+    "from torch.utils.data import WeightedRandomSampler\n",
+    "from processing_pytorch import generate_dataset_grade, CancerDataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## define constant"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "INPUT_SHAPE = 224 ##for attention network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
+    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
+   },
+   "outputs": [],
+   "source": [
+    "SEED = 323\n",
+    "def seed_everything(seed=SEED):\n",
+    "    random.seed(seed)\n",
+    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
+    "    np.random.seed(seed)\n",
+    "    torch.manual_seed(seed)\n",
+    "    torch.cuda.manual_seed(seed)\n",
+    "    torch.backends.cudnn.deterministic = True\n",
+    "seed_everything(SEED)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "main_path = '.'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "_uuid": "5f92a7b85b373990c8aff56efd61fc09193ba337"
+   },
+   "source": [
+    "## Preparing data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_train, y_train, train_paths = generate_dataset_grade(main_path,os.path.join(r'.\\\\','msi_data_train_grade.csv'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_test, y_test, test_paths = generate_dataset_grade(main_path,os.path.join(r'.\\\\','msi_data_test_grade.csv'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_valid, y_valid, valid_paths = generate_dataset_grade(main_path,os.path.join(r'.\\\\','msi_data_valid_grade.csv'))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### balance the training set by downsampling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index_h =list(np.where(y_train == 'non-dysplasia')[0])\n",
+    "index_lg =list(np.where(y_train == 'low grade')[0])\n",
+    "index_hg =list(np.where(y_train == 'high grade')[0])\n",
+    "max_len = len(index_hg)\n",
+    "len_h = len(index_h)\n",
+    "len_lg = len(index_lg)\n",
+    "\n",
+    "balanced_X_train = np.concatenate((np.array([x for i,x in enumerate(X_train) if i in index_h])[np.random.randint(0,len_h,max_len)], np.array([x for i,x in enumerate(X_train) if i in index_lg])[np.random.randint(0,len_lg,max_len)],np.array([x for i,x in enumerate(X_train) if i in index_hg])))\n",
+    "balanced_y_train = np.array(['non-dysplasia']*max_len + ['low grade']*max_len  + ['high grade']*max_len)\n",
+    "balanced_X_train,balanced_y_train = shuffle(balanced_X_train,balanced_y_train,random_state=SEED)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lb = preprocessing.LabelBinarizer()\n",
+    "bin_balanced_y_train = lb.fit_transform(balanced_y_train)\n",
+    "bin_y_test = lb.transform(y_test)\n",
+    "bin_y_val = lb.transform(y_valid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "balanced_X_train=balanced_X_train.astype(np.uint8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bin_balanced_y_train=bin_balanced_y_train.astype(np.uint8)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "_uuid": "f2c6e4883abfb0869f277e512f343c54c70779d6"
+   },
+   "source": [
+    "## Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "74a6d1c81be396c2993e668f0375c16a095cd793"
+   },
+   "outputs": [],
+   "source": [
+    "data_transforms = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.RandomRotate90(p=0.5),\n",
+    "    albumentations.Transpose(p=0.5),\n",
+    "    albumentations.Flip(p=0.5),\n",
+    "    albumentations.OneOf([\n",
+    "        albumentations.CLAHE(clip_limit=2), albumentations.IAASharpen(), albumentations.IAAEmboss(), \n",
+    "        albumentations.RandomBrightness(), albumentations.RandomContrast(),\n",
+    "        albumentations.JpegCompression(), albumentations.Blur(), albumentations.GaussNoise()], p=0.5), \n",
+    "    albumentations.HueSaturationValue(p=0.5), \n",
+    "    albumentations.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.15, rotate_limit=45, p=0.5),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])\n",
+    "\n",
+    "data_transforms_test = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])\n",
+    "\n",
+    "data_transforms_tta0 = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.RandomRotate90(p=0.5),\n",
+    "    albumentations.Transpose(p=0.5),\n",
+    "    albumentations.Flip(p=0.5),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])\n",
+    "\n",
+    "data_transforms_tta1 = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.RandomRotate90(p=1),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])\n",
+    "\n",
+    "data_transforms_tta2 = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.Transpose(p=1),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])\n",
+    "\n",
+    "data_transforms_tta3 = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.Flip(p=1),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])\n",
+    "\n",
+    "dataset = CancerDataset(balanced_X_train, balanced_y_train,  transform=data_transforms)\n",
+    "test_set = CancerDataset(X_test, y_test,  transform=data_transforms_test)\n",
+    "val_set = CancerDataset(X_valid, y_valid,  transform=data_transforms_test)\n",
+    "batch_size = 16\n",
+    "num_workers = 0\n",
+    "# prepare data loaders (combine dataset and sampler)\n",
+    "train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=None, num_workers=num_workers)\n",
+    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, sampler=None, num_workers=num_workers)\n",
+    "valid_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, sampler=None, num_workers=num_workers)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "_uuid": "b137050e2ac7ad1229e7e8ace3bc3a3175340859"
+   },
+   "source": [
+    "## Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "1b8d3de18c8a43415de2c93494c436423d67ebc0"
+   },
+   "outputs": [],
+   "source": [
+    "model_conv = ptcv_get_model(\"cbam_resnet50\", pretrained=True)\n",
+    "model_conv.output = nn.Linear(in_features=2048, out_features=3, bias=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "5d6cefbb5bca7ee5c01c77ddf1e5d6199837582c"
+   },
+   "outputs": [],
+   "source": [
+    "model_conv.cuda()\n",
+    "criterion = nn.CrossEntropyLoss() #weight=class_weights.to(device)\n",
+    "\n",
+    "optimizer = optim.Adam(model_conv.parameters(), lr=0.0001)\n",
+    "\n",
+    "scheduler = StepLR(optimizer, 5, gamma=0.2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if torch.cuda.device_count() > 1:\n",
+    "    print(torch.cuda.device_count() )\n",
+    "    model_conv = nn.DataParallel(model_conv,device_ids=[0,1,2,3])\n",
+    "model_conv.to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def multi_acc(y_pred, y_test):\n",
+    "    y_pred_softmax = torch.log_softmax(y_pred, dim = 1)\n",
+    "    _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)    \n",
+    "    \n",
+    "    correct_pred = (y_pred_tags == y_test).float()\n",
+    "    acc = correct_pred.sum() / len(correct_pred)\n",
+    "    \n",
+    "    acc = torch.round(acc * 100) / 100\n",
+    "    \n",
+    "    return acc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "d384f23b71a56031ae141e5fa4e8e216c862749c",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "EPOCHS=20\n",
+    "train_acc_max = 0\n",
+    "print(\"Begin training.\")\n",
+    "for e in tqdm.tqdm(range(1, EPOCHS+1)):\n",
+    "    \n",
+    "    # TRAINING\n",
+    "    train_epoch_loss = 0\n",
+    "    train_epoch_acc = 0\n",
+    "    model_conv.train()\n",
+    "    for tr_batch_i ,(X_train_batch, y_train_batch)  in enumerate(train_loader):\n",
+    "        X_train_batch, y_train_batch = X_train_batch.to(device), y_train_batch.to(device)\n",
+    "        optimizer.zero_grad()\n",
+    "        \n",
+    "        y_train_pred = model_conv(X_train_batch)\n",
+    "        \n",
+    "        y_train_batch = y_train_batch.argmax(1)\n",
+    "        \n",
+    "        train_loss = criterion(y_train_pred, y_train_batch)\n",
+    "        train_acc = multi_acc(y_train_pred, y_train_batch)\n",
+    "        \n",
+    "        train_loss.backward()\n",
+    "        optimizer.step()\n",
+    "        \n",
+    "        train_epoch_loss += train_loss.item()\n",
+    "        train_epoch_acc += train_acc.item()\n",
+    "        \n",
+    "        \n",
+    "        # TESTING   \n",
+    "        if (tr_batch_i+1)%600 == 0: \n",
+    "            with torch.no_grad():\n",
+    "\n",
+    "                test_epoch_loss = 0\n",
+    "                test_epoch_acc = 0\n",
+    "\n",
+    "                model_conv.eval()\n",
+    "                for X_test_batch, y_test_batch in test_loader:\n",
+    "                    X_test_batch, y_test_batch = X_test_batch.to(device), y_test_batch.to(device)\n",
+    "\n",
+    "                    y_test_pred = model_conv(X_test_batch)\n",
+    "\n",
+    "                    y_test_batch = y_test_batch.argmax(1)\n",
+    "\n",
+    "                    test_loss = criterion(y_test_pred, y_test_batch)\n",
+    "                    test_acc = multi_acc(y_test_pred, y_test_batch)\n",
+    "\n",
+    "                    test_epoch_loss += test_loss.item()\n",
+    "                    test_epoch_acc += test_acc.item()\n",
+    "\n",
+    "                print(f'Epoch {e+0:03}: | Train Loss: {train_epoch_loss/len(train_loader):.5f} | Test Loss: {test_epoch_loss/len(test_loader):.5f} | Train Acc: {train_epoch_acc/len(train_loader):.3f}| Test Acc: {test_epoch_acc/len(test_loader):.3f}')\n",
+    "\n",
+    "                test_acc = test_epoch_acc/len(test_loader)\n",
+    "                train_acc = train_epoch_acc/len(train_loader)\n",
+    "                if train_acc > train_acc_max:\n",
+    "                    print('Training accuracy increased ({:.6f} --> {:.6f}).  Saving model ...'.format(\n",
+    "                    train_acc_max,\n",
+    "                    train_acc))\n",
+    "                    torch.save(model_conv.state_dict(), r\".\\models_attention-grade-classif//model_epoch_{}_test_{:.4f}.pt\".format(e ,(test_acc*100)))\n",
+    "                    train_acc_max = train_acc\n",
+    "    scheduler.step()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "_uuid": "964ed06f462a6c6e6503f5f4259f705e650066a4"
+   },
+   "source": [
+    "## Generate DL features tables"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "71258d16737b84ad8ef4b6f69960224f429d73e5"
+   },
+   "outputs": [],
+   "source": [
+    "model_conv.eval()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "d427380ffd08093b5bc2a751cba7ca08c75332d6"
+   },
+   "outputs": [],
+   "source": [
+    "saved_dict = torch.load(r'.') #load weights\n",
+    "model_conv.load_state_dict(saved_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "new_classifier = nn.Sequential(*list(model_conv.children())[-1].features).cpu()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_transforms = albumentations.Compose([\n",
+    "    albumentations.Resize(INPUT_SHAPE, INPUT_SHAPE),\n",
+    "    albumentations.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),\n",
+    "    AT.ToTensor()\n",
+    "    ])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_features = {}\n",
+    "for i,temp_X in tqdm.tqdm(enumerate(X_train)):\n",
+    "    tensor_X=data_transforms(image=temp_X)\n",
+    "    tensor_X=tensor_X[\"image\"]\n",
+    "    tensor_X.unsqueeze_(0)\n",
+    "    output=torch.flatten(new_classifier(tensor_X)).detach().numpy()\n",
+    "    train_features[train_paths[i]]=output.flatten()\n",
+    "train_features=pd.DataFrame.from_dict(train_features)\n",
+    "train_features=train_features.T\n",
+    "train_features.to_csv(r\".\\train_features_grade_he.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_features = {}\n",
+    "for i,temp_X in tqdm.tqdm(enumerate(X_test)):\n",
+    "    tensor_X=data_transforms(image=temp_X)\n",
+    "    tensor_X=tensor_X[\"image\"]\n",
+    "    tensor_X.unsqueeze_(0)\n",
+    "    output=torch.flatten(new_classifier(tensor_X)).detach().numpy()\n",
+    "    test_features[test_paths[i]]=output.flatten()\n",
+    "test_features=pd.DataFrame.from_dict(test_features)\n",
+    "test_features=test_features.T\n",
+    "test_features.to_csv(r\".\\test_features_grade_he.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "valid_features = {}\n",
+    "for i,temp_X in tqdm.tqdm(enumerate(X_valid)):\n",
+    "    tensor_X=data_transforms(image=temp_X)\n",
+    "    tensor_X=tensor_X[\"image\"]\n",
+    "    tensor_X.unsqueeze_(0)\n",
+    "    output=torch.flatten(new_classifier(tensor_X)).detach().numpy()\n",
+    "    valid_features[valid_paths[i]]=output.flatten()\n",
+    "valid_features=pd.DataFrame.from_dict(valid_features)\n",
+    "valid_features=valid_features.T\n",
+    "valid_features.to_csv(r\".\\valid_features_grade_he.csv\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "_uuid": "5013176b39c0142ff6dfd2a3f013ea2ba2f2560c"
+   },
+   "source": [
+    "## TTA inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "591fa5e73bf4fedf11a15e50c5fcc788cbd4a0b8"
+   },
+   "outputs": [],
+   "source": [
+    "NUM_TTA = 10"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "_uuid": "9dacb064b3fbbf82a9ef49eb0a5eb1b33a74d06f"
+   },
+   "outputs": [],
+   "source": [
+    "sigmoid = lambda x: scipy.special.expit(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def def_tta(X_data,y_data):\n",
+    "    for num_tta in range(NUM_TTA):\n",
+    "        if num_tta==0:\n",
+    "            test_set = CancerDataset(X_data, y_data,  transform=data_transforms_test)\n",
+    "            test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
+    "        elif num_tta==1:\n",
+    "            test_set = CancerDataset(X_data, y_data,  transform=data_transforms_tta1)\n",
+    "            test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
+    "        elif num_tta==2:\n",
+    "            test_set = CancerDataset(X_data, y_data,  transform=data_transforms_tta2)\n",
+    "            test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
+    "        elif num_tta==3:\n",
+    "            test_set = CancerDataset(X_data, y_data,  transform=data_transforms_tta3)\n",
+    "            test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
+    "        elif num_tta<8:\n",
+    "            test_set = CancerDataset(X_data, y_data,  transform=data_transforms_tta0)\n",
+    "            test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
+    "        else:\n",
+    "            test_set = CancerDataset(X_data, y_data,  transform=data_transforms)\n",
+    "            test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers)\n",
+    "\n",
+    "        preds = []\n",
+    "        for batch_i, (data, target) in enumerate(test_loader):\n",
+    "            data, target = data.cuda(), target.cuda()\n",
+    "            output = model_conv(data).detach()\n",
+    "            pr = output.cpu().numpy()\n",
+    "            for i in pr:\n",
+    "                preds.append(sigmoid(i)/NUM_TTA)\n",
+    "        \n",
+    "        array_pred = np.array([item for sublist in preds for item in sublist ]).reshape(-1,3)\n",
+    "        \n",
+    "        \n",
+    "        if num_tta==0:\n",
+    "            test_preds = pd.DataFrame({'imgs': test_set.image_files_list, 'preds 0': array_pred[:,0], 'preds 1': array_pred[:,1], 'preds 2': array_pred[:,2]})\n",
+    "            test_preds['imgs'] = test_preds['imgs'].apply(lambda x: x.split('.')[0])\n",
+    "        else:\n",
+    "            test_preds['preds 0']+=array_pred[:,0]\n",
+    "            test_preds['preds 1']+=array_pred[:,1]\n",
+    "            test_preds['preds 2']+=array_pred[:,2]\n",
+    "            \n",
+    "        print(num_tta)\n",
+    "    return(test_preds)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_preds = def_tta(X_test,y_test)\n",
+    "valid_preds = def_tta(X_valid,y_valid)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## save results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "df_test_preds =  pd.DataFrame(test_preds)\n",
+    "df_test_preds.to_csv(r'.\\sub_10_tta_test.csv', index=False)\n",
+    "df_test_labels = pd.DataFrame(y_test)\n",
+    "df_test_labels.to_csv(r'.\\labels_test.csv', index=False)\n",
+    "flatten_test_preds = np.array(test_preds[['preds 0','preds 1','preds 2']]).ravel()\n",
+    "np.savetxt(r'.\\sub_10_tta_test_flatten.txt',flatten_test_preds)\n",
+    "flatten_y_preds = bin_y_test.ravel()\n",
+    "np.savetxt(r'.\\labels_test_flatten.txt',flatten_y_preds)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_valid_preds =  pd.DataFrame(valid_preds)\n",
+    "df_valid_preds.to_csv(r'.\\sub_10_tta_valid.csv', index=False)\n",
+    "df_valid_labels = pd.DataFrame(y_valid)\n",
+    "df_valid_labels.to_csv(r'.\\labels_valid.csv', index=False)\n",
+    "flatten_valid_preds = np.array(valid_preds[['preds 0','preds 1','preds 2']]).ravel()\n",
+    "np.savetxt(r'.\\sub_10_tta_valid_flatten.txt',flatten_valid_preds)\n",
+    "flatten_y_preds = np.array(bin_y_val).ravel()\n",
+    "np.savetxt(r'.\\labels_valid_flatten.txt',flatten_y_preds)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}