Diff of /utils/utils.py [000000] .. [4cd6c8]


--- a
+++ b/utils/utils.py
@@ -0,0 +1,198 @@
+import math
+import collections
+from itertools import islice
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class SubsetSequentialSampler(Sampler):
+	"""Samples elements sequentially from a given list of indices, without replacement.
+
+	Arguments:
+		indices (sequence): a sequence of indices
+	"""
+	def __init__(self, indices):
+		self.indices = indices
+
+	def __iter__(self):
+		return iter(self.indices)
+
+	def __len__(self):
+		return len(self.indices)
+
+def collate_MIL_mtl_sex(batch):
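+	"""Concatenate feature bags along dim 0 and stack the label, site, and sex targets."""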
+	img = torch.cat([item[0] for item in batch], dim = 0)
+	label = torch.LongTensor([item[1] for item in batch])
+	site = torch.LongTensor([item[2] for item in batch])
+	sex = torch.LongTensor([item[3] for item in batch])
+	return [img, label, site, sex]
+
+def collate_MIL_mtl(batch):
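+	"""Concatenate feature bags along dim 0 and stack the three per-task label tensors."""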
+	img = torch.cat([item[0] for item in batch], dim = 0)
+	label_task1 = torch.LongTensor([item[1] for item in batch])
+	label_task2 = torch.LongTensor([item[2] for item in batch])
+	label_task3 = torch.LongTensor([item[3] for item in batch])
+	return [img, label_task1, label_task2, label_task3]
+
+def collate_MIL(batch):
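+	"""Concatenate feature bags along dim 0 and collect the labels into a LongTensor."""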
+	img = torch.cat([item[0] for item in batch], dim = 0)
+	label = torch.LongTensor([item[1] for item in batch])
+	return [img, label]
+
+def collate_features(batch):
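+	"""Concatenate feature batches along dim 0 and vertically stack their coordinate arrays."""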
+	img = torch.cat([item[0] for item in batch], dim = 0)
+	coords = np.vstack([item[1] for item in batch])
+	return [img, coords]
+
+
+collate_dict = {'MIL': collate_MIL, 'MIL_mtl': collate_MIL_mtl, 'MIL_mtl_sex': collate_MIL_mtl_sex, 'MIL_sex': collate_MIL_mtl}
+
+def get_simple_loader(dataset, batch_size=1, collate_fn='MIL'):
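+	"""Sequential DataLoader over the whole dataset, batching with the collate function keyed by collate_fn."""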
+	kwargs = {'num_workers': 32} if device.type == "cuda" else {}
+	collate = collate_dict[collate_fn]
+	loader = DataLoader(dataset, batch_size=batch_size, sampler=SequentialSampler(dataset), collate_fn=collate, **kwargs)
+	return loader
+
+def get_split_loader(split_dataset, training=False, testing=False, weighted=False, collate_fn='MIL'):
+	"""
+		Return the loader for a split: weighted or random sampling for training,
+		sequential sampling for validation, and a sequential 1% subset when testing.
+	"""
+	collate = collate_dict[collate_fn]
+
+	kwargs = {'num_workers': 4} if device.type == "cuda" else {}
+	if not testing:
+		if training:
+			if weighted:
+				weights = make_weights_for_balanced_classes_split(split_dataset)
+				loader = DataLoader(split_dataset, batch_size=1, sampler=WeightedRandomSampler(weights, len(weights)), collate_fn=collate, **kwargs)
+			else:
+				loader = DataLoader(split_dataset, batch_size=1, sampler=RandomSampler(split_dataset), collate_fn=collate, **kwargs)
+		else:
+			loader = DataLoader(split_dataset, batch_size=1, sampler=SequentialSampler(split_dataset), collate_fn=collate, **kwargs)
+
+	else:
+		# testing mode: evaluate on a random 1% subset of the split, visited sequentially
+		ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset) * 0.01), replace=False)
+		loader = DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate, **kwargs)
+
+	return loader
+
+def get_optim(model, args):
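+	"""Build an Adam or SGD optimizer over the trainable parameters, selected by args.opt."""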
+	if args.opt == "adam":
+		optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
+	elif args.opt == 'sgd':
+		optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
+	else:
+		raise NotImplementedError
+	return optimizer
+
+def print_network(net):
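+	"""Print the module structure along with its total and trainable parameter counts."""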
+	num_params = 0
+	num_params_train = 0
+	print(net)
+
+	for param in net.parameters():
+		n = param.numel()
+		num_params += n
+		if param.requires_grad:
+			num_params_train += n
+
+	print('Total number of parameters: %d' % num_params)
+	print('Total number of trainable parameters: %d' % num_params_train)
+
+
+def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
+	seed = 7, label_frac = 1.0, custom_test_ids = None):
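+	"""
+		Yield (train, val, test) index splits for n_splits folds. Per class, val_num[c]
+		validation ids and (when custom_test_ids is None) test_num[c] test ids are drawn
+		without replacement; label_frac controls how many remaining ids go to training.
+	"""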
+	indices = np.arange(samples).astype(int)
+
+	if custom_test_ids is not None:
+		indices = np.setdiff1d(indices, custom_test_ids)
+
+	np.random.seed(seed)
+	for i in range(n_splits):
+		all_val_ids = []
+		all_test_ids = []
+		sampled_train_ids = []
+
+		if custom_test_ids is not None: # pre-built test split, do not need to sample
+			all_test_ids.extend(custom_test_ids)
+
+		for c in range(len(val_num)):
+			possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class
+			remaining_ids = possible_indices
+
+			if val_num[c] > 0:
+				val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids
+				remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation
+				all_val_ids.extend(val_ids)
+
+			if custom_test_ids is None and test_num[c] > 0: # sample test split
+
+				test_ids = np.random.choice(remaining_ids, test_num[c], replace = False)
+				remaining_ids = np.setdiff1d(remaining_ids, test_ids)
+				all_test_ids.extend(test_ids)
+
+			if label_frac == 1:
+				sampled_train_ids.extend(remaining_ids)
+
+			else:
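+				# take the first ceil(label_frac * n) remaining ids: a deterministic prefix of the (sorted) ids, not a random subsample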
+				sample_num = math.ceil(len(remaining_ids) * label_frac)
+				slice_ids = np.arange(sample_num)
+				sampled_train_ids.extend(remaining_ids[slice_ids])
+
+		yield sampled_train_ids, all_val_ids, all_test_ids
+
+
+def nth(iterator, n, default=None):
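+	"""Return the n-th item of iterator (or default); if n is None, consume the iterator entirely."""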
+	if n is None:
+		return collections.deque(iterator, maxlen=0)
+	else:
+		return next(islice(iterator, n, None), default)
+
+def calculate_error(Y_hat, Y):
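+	"""Misclassification rate: the fraction of predictions in Y_hat that differ from labels Y."""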
+	error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
+
+	return error
+
+def make_weights_for_balanced_classes_split(dataset):
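+	"""Weight each sample inversely to its class frequency, for use with WeightedRandomSampler."""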
+	N = float(len(dataset))
+	weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
+	weight = [0] * int(N)
+	for idx in range(len(dataset)):
+		y = dataset.getlabel(idx)
+		weight[idx] = weight_per_class[y]
+
+	return torch.DoubleTensor(weight)
+
+def initialize_weights(module):
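+	"""Xavier-initialize Linear layers (zero bias); set BatchNorm1d weights to 1 and biases to 0."""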
+	for m in module.modules():
+		if isinstance(m, nn.Linear):
+			nn.init.xavier_normal_(m.weight)
+			m.bias.data.zero_()
+
+		elif isinstance(m, nn.BatchNorm1d):
+			nn.init.constant_(m.weight, 1)
+			nn.init.constant_(m.bias, 0)