import pandas as pd
import numpy as np
import torch
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
from typing import Tuple
from math import ceil
import pickle
import scipy.io
import os
def extract_csv(file):
    '''From a CSV file path, read and return the data as a pandas DataFrame.'''
df = pd.read_csv(file)
return df
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def define_optimizer(args, model):
    optimizer = None
    if args.optimizer_type == 'adam':
        # Note: the 'adam' option uses decoupled weight decay (AdamW).
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer_type == 'adagrad':
        optimizer = torch.optim.Adagrad(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
            initial_accumulator_value=0.1)
    else:
        raise NotImplementedError(
            'optimizer [%s] is not implemented' % args.optimizer_type)
    return optimizer
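# Illustrative usage sketch (not part of the original module): it assumes an
# argparse-style namespace exposing `optimizer_type`, `lr`, and `weight_decay`,
# which matches the attributes read in `define_optimizer` above.
def _example_define_optimizer():
    from types import SimpleNamespace
    args = SimpleNamespace(optimizer_type='adam', lr=1e-4, weight_decay=1e-5)
    model = nn.Linear(16, 1)  # stand-in model for illustration
    optimizer = define_optimizer(args, model)
    print(type(optimizer).__name__, optimizer.param_groups[0]['lr'])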
def define_scheduler(args, optimizer):
if args.lr_policy == 'linear':
        def lambda_rule(epoch):
            # Full LR while (epoch + args.epoch_count) <= args.niter, then a
            # linear decay toward zero over a hardcoded 101-epoch horizon.
            lr_l = 1.0 - max(0, epoch + args.epoch_count - args.niter) / float(100 + 1)
            return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif args.lr_policy == 'exp':
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1)
elif args.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(
optimizer, step_size=args.lr_decay_iters, gamma=0.1)
elif args.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif args.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.niter, eta_min=0)
elif args.lr_policy == 'constant':
scheduler = lr_scheduler.ConstantLR(optimizer, factor=0.5, total_iters=1)
else:
        raise NotImplementedError(
            'learning rate policy [%s] is not implemented' % args.lr_policy)
return scheduler
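# Illustrative usage sketch (assumption: the same argparse-style namespace also
# carries `lr_policy` and, for the cosine policy, `niter`). It shows the usual
# one-scheduler-step-per-epoch pattern.
def _example_define_scheduler():
    from types import SimpleNamespace
    args = SimpleNamespace(optimizer_type='adam', lr=1e-4, weight_decay=1e-5,
                           lr_policy='cosine', niter=50)
    optimizer = define_optimizer(args, nn.Linear(16, 1))
    scheduler = define_scheduler(args, optimizer)
    for _ in range(3):
        optimizer.step()   # parameter update (no-op here, gradients are None)
        scheduler.step()   # advance the LR schedule once per epoch
        print(optimizer.param_groups[0]['lr'])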
def custom_collate(batch):
# Sort the batch based on the "time" values
sorted_batch = sorted(batch, key=lambda x: x[3])
# Unpack the sorted batch
ct_tumor, ct_lymphnodes, y, time, event, ID = zip(*sorted_batch)
# Convert the sorted elements back to tensors
ct_tumor = torch.stack(ct_tumor)
ct_lymphnodes = torch.stack(ct_lymphnodes)
y = torch.tensor(y)
time = torch.tensor(time)
event = torch.tensor(event)
ID = list(ID) # Convert ID back to a list if needed
return ct_tumor, ct_lymphnodes, y, time, event, ID
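# Illustrative sketch of plugging the collate function into a DataLoader
# (assumption: the dataset yields 6-tuples ordered as unpacked in
# `custom_collate`, i.e. (ct_tumor, ct_lymphnodes, y, time, event, ID)).
def _example_custom_collate():
    from torch.utils.data import DataLoader, Dataset

    class _ToySurvivalDataset(Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            ct_tumor = torch.zeros(1, 4, 4, 4)
            ct_lymphnodes = torch.zeros(1, 4, 4, 4)
            return ct_tumor, ct_lymphnodes, 0, float(idx), 1, f"case_{idx}"

    loader = DataLoader(_ToySurvivalDataset(), batch_size=4, collate_fn=custom_collate)
    ct_tumor, ct_lymphnodes, y, time, event, ID = next(iter(loader))
    print(ct_tumor.shape, time)  # each batch is sorted by survival time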
def define_act_layer(act_type='relu'):
if act_type == 'tanh':
act_layer = nn.Tanh()
elif act_type == 'relu':
act_layer = nn.ReLU()
elif act_type == 'gelu':
act_layer = nn.GELU()
elif act_type == 'sigmoid':
act_layer = nn.Sigmoid()
elif act_type == 'LSM':
act_layer = nn.LogSoftmax(dim=1)
elif act_type == "none":
act_layer = None
else:
raise NotImplementedError(
'activation layer [%s] is not found' % act_type)
return act_layer
def compute_metrics(args, preds):
preds_grade, preds_hazard, y, time, event, ID = preds
if args.task=="multitask":
preds_grade = preds_grade.cpu().detach().numpy()
y = y.cpu().detach().numpy()
preds_hazard = preds_hazard.cpu().detach().numpy()
time = time.cpu().detach().numpy()
event = event.cpu().detach().numpy()
ci = concordance_index(time, -preds_hazard, event)
auc = roc_auc_score(y, preds_grade)
return ci, auc
elif args.task=="classification":
preds_grade = preds_grade.cpu().detach().numpy()
time = time.cpu().detach().numpy()
event = event.cpu().detach().numpy()
ci = concordance_index(time, -preds_grade, event)
return ci, 0
elif args.task == "survival":
preds_hazard = preds_hazard.cpu().detach().numpy()
time = time.cpu().detach().numpy()
event = event.cpu().detach().numpy()
ci = concordance_index(time, -preds_hazard, event)
return ci, 0
else:
raise NotImplementedError(
f'task method {args.task} is not implemented')
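# Illustrative sketch with synthetic numbers (assumption: predictions are
# collected as 1-D tensors in the order unpacked by `compute_metrics`, and
# `args.task` selects which metrics are meaningful).
def _example_compute_metrics():
    from types import SimpleNamespace
    args = SimpleNamespace(task="multitask")
    preds_grade = torch.tensor([0.9, 0.2, 0.7, 0.1])
    preds_hazard = torch.tensor([1.5, 0.3, 1.1, 0.2])
    y = torch.tensor([1, 0, 1, 0])
    time = torch.tensor([12.0, 30.0, 18.0, 40.0])
    event = torch.tensor([1, 0, 1, 1])
    ID = ["a", "b", "c", "d"]
    ci, auc = compute_metrics(args, (preds_grade, preds_hazard, y, time, event, ID))
    print(f"c-index={ci:.3f}, AUC={auc:.3f}")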
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
# Preset CT windows, stored as (window width, window level) pairs in Hounsfield units.
CT_WINDOWS = {
"bone": (1800, 400),
"lung": (1500, -600),
"soft_tissue": (800, 50),
"default": (2048, 0)
}
def _w_to_t(ww: int, wl: int) -> Tuple[float, float]:
"""Convert Window width / Window level.
Parameters
----------
ww : int
Window width
wl : int
Window level
Returns
-------
Tuple[int,int]
Lower and upper threshold to use for clipping array values
"""
upper = wl + (ww / 2)
lower = wl - (ww / 2)
return lower, upper
def adjust_ct_window(image, ww, wl):
"""Perform windows adjustement like a radiologist do to visualize its image.
We also perform quantization, to be more robust to differences in images due to the scanning machine.
Concretely, once rescaled between 0 to 255, the values are converted to int8 (effectively removing all decimal values)
then converted back to float32 (for further processing by the model)
"""
window_min, window_max = _w_to_t(ww, wl)
if isinstance(image, np.ndarray):
windowed_img = np.clip(image, window_min, window_max)
else:
raise
windowed_img = (windowed_img - window_min) / (window_max - window_min)
return windowed_img.astype(np.float32)
def lung_window(image):
return adjust_ct_window(image, *CT_WINDOWS["lung"])
def bone_window(image):
return adjust_ct_window(image, *CT_WINDOWS["bone"])
def soft_tissue_window(image):
return adjust_ct_window(image, *CT_WINDOWS["soft_tissue"])
def default_window(image):
return adjust_ct_window(image, *CT_WINDOWS["default"])
def center_crop(img, dim):
    """Crop `img` around its center; `dim` is ordered (width, height, depth)."""
    h, w, d = img.shape[0], img.shape[1], img.shape[2]
    crop_height = dim[1] if dim[1] < img.shape[0] else img.shape[0]
    crop_width = dim[0] if dim[0] < img.shape[1] else img.shape[1]
    crop_depth = dim[2] if dim[2] < img.shape[2] else img.shape[2]
    mid_x, mid_y, mid_z = int(w / 2), int(h / 2), int(d / 2)
    cw2, ch2, cd2 = int(crop_width / 2), int(crop_height / 2), int(crop_depth / 2)
    crop_img = img[mid_y - ch2:mid_y + ch2,
                   mid_x - cw2:mid_x + cw2,
                   mid_z - cd2:mid_z + cd2]
    return crop_img
def random_crop(img, dim, center):
    """Crop `img` around `center`, given as (row, col, slice); `dim` is (width, height, depth)."""
    crop_height = dim[1] if dim[1] < img.shape[0] else img.shape[0]
    crop_width = dim[0] if dim[0] < img.shape[1] else img.shape[1]
    crop_depth = dim[2] if dim[2] < img.shape[2] else img.shape[2]
    mid_x, mid_y, mid_z = center[1], center[0], center[2]
    cw2, ch2, cd2 = int(crop_width / 2), int(crop_height / 2), ceil(crop_depth / 2)
    crop_img = img[mid_y - ch2:mid_y + ch2,
                   mid_x - cw2:mid_x + cw2,
                   mid_z - cd2:mid_z + cd2]
    return crop_img
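# Illustrative sketch of the two crop helpers (assumption: the volume is laid
# out as (height, width, depth), matching the indexing used above).
def _example_cropping():
    volume = np.zeros((128, 128, 64), dtype=np.float32)
    centered = center_crop(volume, (64, 64, 32))                # dim = (width, height, depth)
    shifted = random_crop(volume, (64, 64, 32), (70, 60, 30))   # center = (row, col, slice)
    print(centered.shape, shifted.shape)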
def save_results_to_mat(split, args, model_name):
file_path = os.path.join(args.checkpoints_dir, args.exp_name, model_name, f'pred_{split}.pkl')
data = pickle.load(open(file_path, "rb"))
flattened_list = [item for sublist in data[5] for item in sublist]
IDs = np.asarray(flattened_list)
matlab_dict = {
f'{split}_ID': IDs,
f'{split}_score': data[0].cpu().detach().numpy(),
f'{split}_surv': data[3].cpu().detach().numpy(),
f'{split}_censor': data[4].cpu().detach().numpy()
}
mat_file_path = f"C:\\Users\\bsong47\\OneDrive - Emory University\\Documents\\MATLAB\\swinradiomic_{split}_data.mat"
scipy.io.savemat(mat_file_path, matlab_dict)