--- a/utils.py
+++ b/utils.py
@@ -0,0 +1,245 @@
+import os
+import pickle
+from math import ceil
+from typing import Tuple
+
+import numpy as np
+import pandas as pd
+import scipy.io
+import torch
+import torch.nn as nn
+import torch.optim.lr_scheduler as lr_scheduler
+from lifelines.utils import concordance_index
+from sklearn.metrics import roc_auc_score
+
+
+def extract_csv(file):
+    """Read a csv file from its path and return it as a DataFrame."""
+    df = pd.read_csv(file)
+    return df
+
+
+def count_parameters(model):
+    """Return the number of trainable parameters of a model."""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def define_optimizer(args, model):
+    """Build the optimizer selected by args.optimizer_type."""
+    optimizer = None
+    if args.optimizer_type == 'adam':
+        # AdamW (decoupled weight decay) is used for the 'adam' option.
+        optimizer = torch.optim.AdamW(
+            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    elif args.optimizer_type == 'adagrad':
+        optimizer = torch.optim.Adagrad(
+            model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
+            initial_accumulator_value=0.1)
+    else:
+        raise NotImplementedError(
+            'optimizer [%s] is not implemented' % args.optimizer_type)
+    return optimizer
+
+
+def define_scheduler(args, optimizer):
+    """Build the learning-rate scheduler selected by args.lr_policy."""
+    if args.lr_policy == 'linear':
+        def lambda_rule(epoch):
+            # Linearly decay the multiplier once epoch + epoch_count exceeds niter.
+            lr_l = 1.0 - max(0, epoch + args.epoch_count - args.niter) / float(100 + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif args.lr_policy == 'exp':
+        scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1)
+    elif args.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_iters, gamma=0.1)
+    elif args.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(
+            optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    elif args.lr_policy == 'cosine':
+        scheduler = lr_scheduler.CosineAnnealingLR(
+            optimizer, T_max=args.niter, eta_min=0)
+    elif args.lr_policy == 'constant':
+        scheduler = lr_scheduler.ConstantLR(optimizer, factor=0.5, total_iters=1)
+    else:
+        raise NotImplementedError(
+            'learning rate policy [%s] is not implemented' % args.lr_policy)
+    return scheduler
+
+
+def custom_collate(batch):
+    """Collate (ct_tumor, ct_lymphnodes, y, time, event, ID) samples into a batch,
+    sorted by ascending survival time (element at index 3 of each sample)."""
+    sorted_batch = sorted(batch, key=lambda x: x[3])
+
+    # Unpack the sorted batch
+    ct_tumor, ct_lymphnodes, y, time, event, ID = zip(*sorted_batch)
+
+    # Convert the sorted elements back to tensors
+    ct_tumor = torch.stack(ct_tumor)
+    ct_lymphnodes = torch.stack(ct_lymphnodes)
+    y = torch.tensor(y)
+    time = torch.tensor(time)
+    event = torch.tensor(event)
+    ID = list(ID)  # IDs stay a plain Python list
+
+    return ct_tumor, ct_lymphnodes, y, time, event, ID
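+
+
+# Illustrative usage sketch (not part of the training code): it assumes a Dataset
+# whose __getitem__ returns (ct_tumor, ct_lymphnodes, y, time, event, ID) tuples
+# and shows how custom_collate plugs into a standard DataLoader.
+#
+#   from torch.utils.data import DataLoader
+#   loader = DataLoader(dataset, batch_size=8, collate_fn=custom_collate)
+#   ct_tumor, ct_lymphnodes, y, time, event, ID = next(iter(loader))
+#   # `time` is sorted ascending within each batch, as custom_collate guarantees.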
+
+
+def define_act_layer(act_type='relu'):
+    """Return the activation layer selected by act_type (None for 'none')."""
+    if act_type == 'tanh':
+        act_layer = nn.Tanh()
+    elif act_type == 'relu':
+        act_layer = nn.ReLU()
+    elif act_type == 'gelu':
+        act_layer = nn.GELU()
+    elif act_type == 'sigmoid':
+        act_layer = nn.Sigmoid()
+    elif act_type == 'LSM':
+        act_layer = nn.LogSoftmax(dim=1)
+    elif act_type == "none":
+        act_layer = None
+    else:
+        raise NotImplementedError(
+            'activation layer [%s] is not found' % act_type)
+    return act_layer
+
+
+def compute_metrics(args, preds):
+    """Compute (concordance index, AUC) from a (preds_grade, preds_hazard, y, time,
+    event, ID) tuple. The AUC slot is 0 for single-task settings."""
+    preds_grade, preds_hazard, y, time, event, ID = preds
+    if args.task == "multitask":
+        preds_grade = preds_grade.cpu().detach().numpy()
+        y = y.cpu().detach().numpy()
+        preds_hazard = preds_hazard.cpu().detach().numpy()
+        time = time.cpu().detach().numpy()
+        event = event.cpu().detach().numpy()
+
+        # Higher hazard should mean shorter survival, hence the negation.
+        ci = concordance_index(time, -preds_hazard, event)
+        auc = roc_auc_score(y, preds_grade)
+        return ci, auc
+    elif args.task == "classification":
+        preds_grade = preds_grade.cpu().detach().numpy()
+        time = time.cpu().detach().numpy()
+        event = event.cpu().detach().numpy()
+        # The grade prediction is scored against survival time in this setting.
+        ci = concordance_index(time, -preds_grade, event)
+        return ci, 0
+    elif args.task == "survival":
+        preds_hazard = preds_hazard.cpu().detach().numpy()
+        time = time.cpu().detach().numpy()
+        event = event.cpu().detach().numpy()
+        ci = concordance_index(time, -preds_hazard, event)
+        return ci, 0
+    else:
+        raise NotImplementedError(
+            f'task method {args.task} is not implemented')
+
+
+def get_lr(optimizer):
+    """Return the learning rate of the first parameter group."""
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
+
+# (window width, window level) pairs in Hounsfield units
+CT_WINDOWS = {
+    "bone": (1800, 400),
+    "lung": (1500, -600),
+    "soft_tissue": (800, 50),
+    "default": (2048, 0)
+}
+
+
+def _w_to_t(ww: int, wl: int) -> Tuple[float, float]:
+    """Convert window width / window level to clipping thresholds.
+
+    Parameters
+    ----------
+    ww : int
+        Window width
+    wl : int
+        Window level
+
+    Returns
+    -------
+    Tuple[float, float]
+        Lower and upper thresholds to use for clipping array values
+    """
+    upper = wl + (ww / 2)
+    lower = wl - (ww / 2)
+    return lower, upper
+
+
+def adjust_ct_window(image, ww, wl):
+    """Perform window adjustment, as a radiologist would to visualize the image.
+
+    The image is clipped to the chosen window, rescaled to [0, 1] and cast to
+    float32 for further processing by the model, which makes the input more
+    robust to intensity differences between scanning machines.
+    """
+    window_min, window_max = _w_to_t(ww, wl)
+
+    if isinstance(image, np.ndarray):
+        windowed_img = np.clip(image, window_min, window_max)
+    else:
+        raise TypeError(f"expected a numpy array, got {type(image)}")
+    windowed_img = (windowed_img - window_min) / (window_max - window_min)
+    return windowed_img.astype(np.float32)
+
+
+def lung_window(image):
+    return adjust_ct_window(image, *CT_WINDOWS["lung"])
+
+
+def bone_window(image):
+    return adjust_ct_window(image, *CT_WINDOWS["bone"])
+
+
+def soft_tissue_window(image):
+    return adjust_ct_window(image, *CT_WINDOWS["soft_tissue"])
+
+
+def default_window(image):
+    return adjust_ct_window(image, *CT_WINDOWS["default"])
+
+
+def center_crop(img, dim):
+    """Crop a (H, W, D) volume to dim = (width, height, depth) around its center."""
+    h, w, d = img.shape[0], img.shape[1], img.shape[2]
+    crop_height = dim[1] if dim[1] < img.shape[0] else img.shape[0]
+    crop_width = dim[0] if dim[0] < img.shape[1] else img.shape[1]
+    crop_depth = dim[2] if dim[2] < img.shape[2] else img.shape[2]
+
+    mid_x, mid_y, mid_z = int(w / 2), int(h / 2), int(d / 2)
+    cw2, ch2, cd2 = int(crop_width / 2), int(crop_height / 2), int(crop_depth / 2)
+    crop_img = img[mid_y - ch2:mid_y + ch2,
+                   mid_x - cw2:mid_x + cw2,
+                   mid_z - cd2:mid_z + cd2]
+
+    return crop_img
+
+
+def random_crop(img, dim, center):
+    """Crop a (H, W, D) volume to dim = (width, height, depth) around the given
+    center = (row, col, slice); despite the name, the location is set by center."""
+    crop_height = dim[1] if dim[1] < img.shape[0] else img.shape[0]
+    crop_width = dim[0] if dim[0] < img.shape[1] else img.shape[1]
+    crop_depth = dim[2] if dim[2] < img.shape[2] else img.shape[2]
+
+    mid_x, mid_y, mid_z = center[1], center[0], center[2]
+    cw2, ch2, cd2 = int(crop_width / 2), int(crop_height / 2), ceil(crop_depth / 2)
+    crop_img = img[mid_y - ch2:mid_y + ch2,
+                   mid_x - cw2:mid_x + cw2,
+                   mid_z - cd2:mid_z + cd2]
+
+    return crop_img
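+
+
+# Illustrative sketch (not part of the pipeline): applying a preset window and a
+# center crop to a dummy (H, W, D) CT volume in Hounsfield units. The shape and
+# crop size below are arbitrary placeholder values.
+#
+#   volume = np.random.randint(-1000, 2000, size=(128, 128, 64)).astype(np.float32)
+#   windowed = soft_tissue_window(volume)          # values rescaled to [0, 1]
+#   patch = center_crop(windowed, (64, 64, 32))    # dim is (width, height, depth)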
+
+
+def save_results_to_mat(split, args, model_name):
+    """Export the cached predictions for a split to a MATLAB .mat file."""
+    file_path = os.path.join(args.checkpoints_dir, args.exp_name, model_name,
+                             f'pred_{split}.pkl')
+    # The pickle holds the (preds_grade, preds_hazard, y, time, event, ID) tuple.
+    with open(file_path, "rb") as f:
+        data = pickle.load(f)
+
+    flattened_list = [item for sublist in data[5] for item in sublist]
+    IDs = np.asarray(flattened_list)
+
+    matlab_dict = {
+        f'{split}_ID': IDs,
+        f'{split}_score': data[0].cpu().detach().numpy(),
+        f'{split}_surv': data[3].cpu().detach().numpy(),
+        f'{split}_censor': data[4].cpu().detach().numpy()
+    }
+
+    mat_file_path = f"C:\\Users\\bsong47\\OneDrive - Emory University\\Documents\\MATLAB\\swinradiomic_{split}_data.mat"
+    scipy.io.savemat(mat_file_path, matlab_dict)
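+
+
+# Illustrative sketch (not part of the pipeline): exporting cached test-set
+# predictions to MATLAB. The model_name below is a hypothetical example; it must
+# match the sub-directory under which pred_test.pkl was written.
+#
+#   save_results_to_mat("test", args, model_name="swin_fusion")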