--- a
+++ b/Learning/train.py
@@ -0,0 +1,152 @@
+import argparse
+
+import os
+import random
+import pandas as pd
+import numpy as np
+import time
+
+import torch
+
+from load_data import VoxelDataset
+from torch.utils.data import DataLoader, Subset
+
+from model import DeepDrug3D
+
+from sklearn.metrics import confusion_matrix
+from sklearn.model_selection import StratifiedKFold
+from sklearn.metrics import accuracy_score
+
+# Seed the Python and PyTorch RNGs and pick the device used for the run.
+seed = 12306
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+random.seed(seed)
+torch.manual_seed(seed)
+if device.type == 'cuda':
+    torch.cuda.manual_seed_all(seed)
+print('Current device: ' + str(device))
+
+def _run_epoch(model, loader, criterion, in_channel, optimizer=None):
+    """Run one pass over ``loader``; train when ``optimizer`` is given.
+
+    Returns ``(mean_loss, accuracy_percent, y_true, y_pred)``. Loss and
+    accuracy are averaged per sample, so a short final batch is not
+    over-weighted the way a mean of per-batch means would be.
+    """
+    training = optimizer is not None
+    model.train(training)
+    loss_sum, n_correct, n_seen = 0.0, 0, 0
+    y_trues, y_preds = [], []
+    for batch in loader:
+        data = batch['voxel'].float()
+        if in_channel == 1:
+            data = torch.unsqueeze(data, 1)
+        data = data.to(device)
+        # reshape(-1), not squeeze(): squeeze() turns a batch holding a single
+        # label into a 0-d tensor, which breaks the loss/accuracy computation.
+        y = batch['label'].reshape(-1).to(device)
+        with torch.set_grad_enabled(training):
+            output = model(data)
+            loss = criterion(output, y)
+        if training:
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        # Accumulate sample-weighted statistics for this batch.
+        pred = output.detach().argmax(dim=1)
+        n_batch = y.size(0)
+        loss_sum += loss.item() * n_batch
+        n_correct += int((pred == y).sum().item())
+        n_seen += n_batch
+        y_trues += y.cpu().tolist()
+        y_preds += pred.cpu().tolist()
+    n_seen = max(n_seen, 1)  # an empty loader yields 0.0 instead of a crash
+    return loss_sum / n_seen, 100.0 * n_correct / n_seen, y_trues, y_preds
+
+
+def main(opt):
+    """Train DeepDrug3D with stratified 2-fold cross-validation.
+
+    Args:
+        opt: parsed CLI namespace with ``path`` (data folder), ``lpath``
+            (label CSV with ``id`` and ``class`` columns), ``opath`` (output
+            folder, ``./logs`` when None), ``bs``, ``lr`` and ``epoch``.
+
+    For every fold, the weights with the lowest validation loss and the
+    per-epoch loss/accuracy curves are written to the output folder.
+    """
+    in_channel = 14
+    model = DeepDrug3D(in_channel)
+    print(model)
+    model = model.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    if opt.opath is None:
+        opt.opath = './logs'
+    # exist_ok=True: a re-run must not crash on an existing ./logs, and a
+    # user-supplied folder that does not exist yet is created up front.
+    os.makedirs(opt.opath, exist_ok=True)
+
+    labels = pd.read_csv(opt.lpath)
+    xid = labels['id'].tolist()
+    ys = labels['class'].tolist()
+    dataset = VoxelDataset(label_file=opt.lpath, root_dir=opt.path)
+    kfold = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)
+    bs = opt.bs
+    for fold, (train_id, val_id) in enumerate(kfold.split(xid, ys), start=1):
+        train_loader = DataLoader(Subset(dataset, train_id), batch_size=bs, shuffle=True)
+        # Validation metrics do not depend on sample order, so do not shuffle.
+        val_loader = DataLoader(Subset(dataset, val_id), batch_size=bs, shuffle=False)
+
+        tr_losses = np.zeros((opt.epoch,))
+        tr_accs = np.zeros((opt.epoch,))
+        val_losses = np.zeros((opt.epoch,))
+        val_accs = np.zeros((opt.epoch,))
+
+        # Every fold starts from re-initialised weights and a new optimizer.
+        model.reset_parameters()
+        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
+        best_val_loss = 1e6
+
+        print('===================Fold {} starts==================='.format(fold))
+        for epoch in range(opt.epoch):
+            s = time.time()
+            tr_loss, tr_acc, _, _ = _run_epoch(
+                model, train_loader, criterion, in_channel, optimizer)
+            val_loss, val_acc, y_trues, y_preds = _run_epoch(
+                model, val_loader, criterion, in_channel)
+            tr_losses[epoch], tr_accs[epoch] = tr_loss, tr_acc
+            val_losses[epoch], val_accs[epoch] = val_loss, val_acc
+            cnf = confusion_matrix(y_trues, y_preds)
+
+            # Checkpoint whenever the validation loss reaches a new minimum.
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                ckpt_name = 'best_model_fold_{}.ckpt'.format(fold)
+                torch.save(model.state_dict(), os.path.join(opt.opath, ckpt_name))
+
+            print('Epoch: {:03d} | time: {:.4f} seconds\n'
+                  'Train Loss: {:.4f} | Train accuracy {:.4f}\n'
+                  'Validation Loss: {:.4f} | Validation accuracy {:.4f} | Best {:.4f}'.format(
+                      epoch+1, time.time()-s, tr_loss, tr_acc, val_loss, val_acc, best_val_loss))
+            print('Validation confusion matrix:')
+            print(cnf)
+
+        print('===================Fold {} ends==================='.format(fold))
+        # Persist the per-epoch curves of this fold.
+        np.save(os.path.join(opt.opath, 'train_loss_{}.npy'.format(fold)), tr_losses)
+        np.save(os.path.join(opt.opath, 'train_acc_{}.npy'.format(fold)), tr_accs)
+        np.save(os.path.join(opt.opath, 'val_loss_{}.npy'.format(fold)), val_losses)
+        np.save(os.path.join(opt.opath, 'val_acc_{}.npy'.format(fold)), val_accs)
+        
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()  # every option except --opath is mandatory
+    for flag, kind, needed, text in (
+            ('--path', str, True, 'path to data folder'),
+            ('--lpath', str, True, 'path to label file'),
+            ('--opath', str, False, 'output folder name'),
+            ('--bs', int, True, 'batch size'),
+            ('--lr', float, True, 'learning rate'),
+            ('--epoch', int, True, 'number of epochs to train for')):
+        parser.add_argument(flag, type=kind, required=needed, help=text)
+    main(parser.parse_args())