Diff of /code/train_student.py [000000] .. [118be8]

Switch to side-by-side view

--- a
+++ b/code/train_student.py
@@ -0,0 +1,187 @@
+import pandas as pd
+from mil_data_generator import *
+from mil_models_pytorch import*
+from mil_trainer_torch import *
+from sklearn.utils.class_weight import compute_class_weight
+import torch
+from torchvision import transforms
+import argparse
+
+np.random.seed(42)
+random.seed(42)
+torch.manual_seed(42)
+
+
+def main(args):
+
+    # INPUTS #
+    dir_images = '../data/SICAP_MIL/patches/'
+    dir_data_frame = '../data/SICAP_MIL/dataframes/gt_global_slides.xlsx'
+    dir_data_frame_test = '../data/SICAP_MIL/dataframes/gt_test_patches.xlsx'
+    dir_experiment = '../data/results/' + args.experiment_name + '/'
+
+    classes = ['G3', 'G4', 'G5']
+    proportions = ['pG3', 'pG4', 'pG5']
+    input_shape = (3, 224, 224)
+    images_on_ram = True
+    data_augmentation = True
+    pMIL = False
+    aggregation = 'max'  # 'max', 'mean', 'attentionMIL', 'mcAttentionMIL'
+    mode = 'instance'  # 'embedding', 'instance', 'mixed'
+    include_background = True
+    iterations = 3
+
+    df = pd.read_excel(dir_data_frame)
+
+    metrics = []
+    for ii_iteration in np.arange(0, iterations):
+
+        # Set data generators
+        dataset_train = MILDataset(dir_images, df[df['Partition'] == 'train'], classes, bag_id='slide_name',
+                                   input_shape=input_shape, data_augmentation=False, images_on_ram=images_on_ram,
+                                   pMIL=pMIL, proportions=proportions)
+        data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512)
+
+        dataset_test = MILDataset(dir_images, df[df['Partition'] == 'test'], classes, bag_id='slide_name',
+                                  input_shape=input_shape, data_augmentation=False, images_on_ram=images_on_ram,
+                                  pMIL=pMIL, proportions=proportions, dataframe_instances=pd.read_excel(dir_data_frame_test))
+        data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512)
+
+        # Test at instance level
+        X_test = data_generator_test.dataset.X[data_generator_test.dataset.y_instances[:, 0] != -1, :, :, :]
+        Y_test = data_generator_test.dataset.y_instances[data_generator_test.dataset.y_instances[:, 0] != -1, :]
+
+        # Load network
+        network = torch.load(dir_experiment + str(ii_iteration) + '_network_weights_best.pth')
+
+        # Pseudolabels on training set
+        labels = []
+        yhat_one_hot = []
+        Yglobal = data_generator_train.dataset.Yglobal
+        X = data_generator_train.dataset.X
+
+        for i in np.arange(0, X.shape[0]):
+            print(str(i + 1) + '/' + str(X.shape[0]), end='\r')
+
+            # Tensorize input
+            x = torch.tensor(X[i, :, :, :]).cuda().float()
+            x = x.unsqueeze(0)
+
+            features = network.bb(x)
+            yhat = torch.softmax(network.classifier(torch.squeeze(features)), 0)
+            yhat = yhat.detach().cpu().numpy()
+            yhat_one_hot.append(yhat)
+
+            if np.max(Yglobal[i, 1:]) == 0:
+                labels.append(0)
+            else:
+                if np.argmax(yhat) > 0:
+                    if Yglobal[i, np.argmax(yhat)] == 1 and yhat[np.argmax(yhat)] > 0.5:
+                        labels.append(np.argmax(yhat))
+                    else:
+                        labels.append(10)
+                else:
+                    labels.append(10)
+        labels = np.array(labels)
+        yhat_one_hot = np.array(yhat_one_hot)
+
+        X = X[labels != 10, :, :, :]
+        Y = labels[labels != 10]
+        images_id = np.array(dataset_train.images)[labels != 10]
+        class_weights = compute_class_weight('balanced', [0, 1, 2, 3], Y)
+
+        # Set student network architecture
+        lr = 1e-2
+        network = MILArchitecture(classes, mode=mode, aggregation=aggregation, backbone='vgg19',
+                                  include_background=include_background).cuda()
+        opt = torch.optim.SGD(network.parameters(), lr=lr)
+
+        tranf = torch.nn.Sequential(transforms.RandomHorizontalFlip(),
+                                    transforms.RandomRotation(degrees=(-45, 45)),
+                                    transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
+                                    transforms.ColorJitter(brightness=.5, hue=.3)).cuda()
+
+        training_data = CustomImageDataset(X, Y, transform=False)
+        train_dataloader = CustomGenerator(training_data, bs=32, shuffle=True)
+
+
+        def test_instances(X, Y, network, dir_out, i_iteration):
+            network.eval()
+            Yhat = []
+            for i in np.arange(0, X.shape[0]):
+                print(str(i + 1) + '/' + str(X.shape[0]), end='\r')
+
+                # Tensorize input
+                x = torch.tensor(X[i, :, :, :]).cuda().float()
+                x = x.unsqueeze(0)
+
+                features = network.bb(x)
+                yhat = torch.softmax(network.classifier(torch.squeeze(features)), 0)
+                yhat = torch.argmax(yhat).detach().cpu().numpy()
+
+                Yhat.append(yhat)
+
+            Yhat = np.array(Yhat)
+            Y = np.argmax(Y, 1)
+
+            cr = classification_report(Y, Yhat, target_names=['NC'] + classes, digits=4)
+            cm = confusion_matrix(Y, Yhat)
+            k2 = cohen_kappa_score(Y, Yhat, weights='quadratic')
+
+            f = open(dir_out + str(i_iteration) + '_report_student.txt', 'w')
+            f.write('Title\n\nClassification Report\n\n{}\n\nConfusion Matrix\n\n{}\n\nKappa\n\n{}\n'.format(cr, cm, k2))
+            f.close()
+
+            return k2
+
+        # STUDENT TRAINING
+
+
+        epochs = 60
+        dropout_rate = 0.2
+        for i_epoch in np.arange(0, epochs):
+            l_epoch = 0
+
+            if (i_epoch + 1) % 25 == 0:
+                for g in opt.param_groups:
+                    g['lr'] = g['lr'] / 2
+
+            for i_iteration, (X, Y) in enumerate(train_dataloader):
+                # Set model to training mode and clear gradients
+                network.train()
+                opt.zero_grad()
+
+                X = X.cuda().float()
+                X = tranf(X)
+
+                logits = network.classifier(torch.nn.Dropout(dropout_rate)(torch.squeeze(network.bb(X))))
+
+                L = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights).cuda().float())(logits, Y.type(torch.LongTensor).cuda())
+
+                L.backward()
+                opt.step()
+
+                L_iteration = L.detach().cpu().numpy()
+                l_epoch += L_iteration
+
+                info = "[INFO] Epoch {}/{}  -- Step {}/{}: Lce={:.6f}".format(
+                    i_epoch + 1, epochs, i_iteration + 1, len(train_dataloader), L_iteration)
+                print(info, end='\r')
+
+            l_epoch = l_epoch/len(train_dataloader)
+
+            k2 = test_instances(X_test, Y_test, network, dir_experiment, ii_iteration)
+
+            info = "[INFO] Epoch {}/{}  -- Step {}/{}: Lce={:.6f}; k2={:.6f}".format(
+                i_epoch+1, epochs, i_iteration, len(train_dataloader), l_epoch, k2)
+            print(info, end='\n')
+
+        torch.save(network, dir_experiment + str(ii_iteration) + '_student_network_weights.pth')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--experiment_name", default="test_test_test", type=str)
+
+    args = parser.parse_args()
+    main(args)