--- a
+++ b/pathflowai/datasets.py
@@ -0,0 +1,616 @@
+"""
+datasets.py
+=======================
+Houses the DynamicImageDataset class, along with helper functions for image color-channel normalization, data transformers, etc.
+"""
+
+import torch
+from torchvision import transforms
+import os
+import dask
+#from dask.distributed import Client; Client()
+import dask.array as da, pandas as pd, numpy as np
+from pathflowai.utils import *
+import pysnooper
+import nonechucks as nc
+from torch.utils.data import Dataset, DataLoader
+import random
+import albumentations as alb
+import copy
+from albumentations import pytorch as albtorch
+from sklearn.preprocessing import LabelBinarizer
+from sklearn.utils.class_weight import compute_class_weight
+from pathflowai.losses import class2one_hot
+import cv2
+from scipy.ndimage import generate_binary_structure # scipy.ndimage.morphology is deprecated; scipy.ndimage re-exports the same function
+from dask_image.ndmorph import binary_dilation
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+
+def RandomRotate90():
+	"""Transformer that rotates an image by a random multiple of 90 degrees.
+
+	Returns
+	-------
+	function
+		Transformer function for operation.
+
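+	Examples
+	--------
+	A usage sketch; pil_image is a hypothetical PIL.Image:
+
+	>>> rotate = RandomRotate90()
+	>>> rotated = rotate(pil_image)
+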
+	"""
+	return (lambda img: img.rotate(random.sample([0, 90, 180, 270], k=1)[0]))
+
+def get_data_transforms(patch_size = None, mean=[], std=[], resize=False, transform_platform='torch', elastic=True, user_transforms=dict()):
+	"""Get data transformers for the training, validation, and test sets.
+
+	Parameters
+	----------
+	patch_size:int
+		Original patch size being transformed.
+	mean:list of float
+		Mean RGB
+	std:list of float
+		Std RGB
+	resize:bool
+		Whether to resize patches (currently unused; patches are always resized to patch_size).
+	transform_platform:str
+		Use pytorch or albumentations transforms.
+	elastic:bool
+		Whether to add elastic deformations from albumentations (currently unused; the default albumentations set already includes them).
+	user_transforms:dict
+		Optional mapping from transform name to keyword arguments, replacing the default transform set; a 'normalization' entry supplies mean and std.
+
+	Returns
+	-------
+	dict
+		Transformers.
+
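+	Examples
+	--------
+	A minimal sketch; the mean/std values here are illustrative, not dataset-specific:
+
+	>>> transformers = get_data_transforms(patch_size=224, mean=[0.7, 0.6, 0.7],
+	... 	std=[0.15, 0.15, 0.15], transform_platform='albumentations')
+	>>> train_pipeline = transformers['train']
+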
+	"""
+	transform_dict=dict(torch=dict(
+									colorjitter=lambda kargs: transforms.ColorJitter(**kargs),
+									hflip=lambda kargs: transforms.RandomHorizontalFlip(),
+									vflip=lambda kargs: transforms.RandomVerticalFlip(),
+									r90= lambda kargs: RandomRotate90()
+									),
+						albumentations=dict(
+							huesaturation=lambda kargs: alb.augmentations.transforms.HueSaturationValue(**kargs),
+							flip=lambda kargs: alb.augmentations.transforms.Flip(**kargs),
+							transpose=lambda kargs: alb.augmentations.transforms.Transpose(**kargs),
+							affine=lambda kargs: alb.augmentations.transforms.ShiftScaleRotate(**kargs),
+							r90=lambda kargs: alb.augmentations.transforms.RandomRotate90(**kargs),
+							elastic=lambda kargs: alb.augmentations.transforms.ElasticTransform(**kargs)
+						))
+	if 'normalization' in user_transforms:
+		mean=user_transforms['normalization'].pop('mean')
+		std=user_transforms['normalization'].pop('std')
+		del user_transforms['normalization']
+	default_transforms=dict() # add normalization custom
+	default_transforms['torch']=dict(
+							colorjitter=dict(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.5),
+							hflip=dict(),
+							vflip=dict(),
+							r90=dict())
+	default_transforms['albumentations']=dict(
+							huesaturation=dict(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
+							r90=dict(p=0.5),
+							elastic=dict(p=0.5))
+	main_transforms = default_transforms[transform_platform] if not user_transforms else user_transforms
+	print(main_transforms)
+	train_transforms=[transform_dict[transform_platform][k](v) for k,v in main_transforms.items()]
+	torch_init=[transforms.ToPILImage(),transforms.Resize((patch_size,patch_size)),transforms.CenterCrop(patch_size)]
+	albu_init=[alb.augmentations.transforms.Resize(patch_size, patch_size),
+				alb.augmentations.transforms.CenterCrop(patch_size, patch_size)]
+	tensor_norm=[transforms.ToTensor(),transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])] # default mean and standard deviations measured on lung adenocarcinoma resection slides; an empty std list would otherwise be passed to Normalize
+	data_transforms = { 'torch': {
+		'train': transforms.Compose(torch_init+train_transforms+tensor_norm),
+		'val': transforms.Compose([
+			transforms.ToPILImage(),
+			transforms.Resize((patch_size,patch_size)),
+			transforms.CenterCrop(patch_size),
+			transforms.ToTensor(),
+			transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])
+		]),
+		'test': transforms.Compose([
+			transforms.ToPILImage(),
+			transforms.Resize((patch_size,patch_size)),
+			transforms.CenterCrop(patch_size),
+			transforms.ToTensor(),
+			transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])
+		]),
+		'pass': transforms.Compose([
+			transforms.ToPILImage(),
+			transforms.CenterCrop(patch_size),
+			transforms.ToTensor(),
+		])
+	},
+	'albumentations':{
+	'train':alb.core.composition.Compose(albu_init+train_transforms),
+	'val':alb.core.composition.Compose([
+		alb.augmentations.transforms.Resize(patch_size, patch_size),
+		alb.augmentations.transforms.CenterCrop(patch_size, patch_size)
+	]),
+	'test':alb.core.composition.Compose([
+		alb.augmentations.transforms.Resize(patch_size, patch_size),
+		alb.augmentations.transforms.CenterCrop(patch_size, patch_size)
+	]),
+	'normalize':transforms.Compose([transforms.Normalize(mean if mean else [0.7, 0.6, 0.7], std if std else [0.15, 0.15, 0.15])])
+	}}
+
+	return data_transforms[transform_platform]
+
+def create_transforms(mean, std):
+	"""Create transformers.
+
+	Parameters
+	----------
+	mean:list
+		See get_data_transforms.
+	std:list
+		See get_data_transforms.
+
+	Returns
+	-------
+	dict
+		Transformers.
+
+	"""
+	return get_data_transforms(patch_size = 224, mean=mean, std=std, resize=True)
+
+
+
+def get_normalizer(normalization_file, dataset_opts):
+	"""Find mean and standard deviation of images in batches.
+
+	Parameters
+	----------
+	normalization_file:str
+		File to store normalization information.
+	dataset_opts:dict
+		Keyword arguments used to construct the DynamicImageDataset.
+
+	Returns
+	-------
+	dict
+		Stores RGB mean, stdev.
+
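+	Examples
+	--------
+	A sketch; the file name is hypothetical and dataset_opts holds DynamicImageDataset kwargs:
+
+	>>> norm = get_normalizer('normalization.pkl', dataset_opts)
+	>>> norm['mean'], norm['std']
+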
+	"""
+	if os.path.exists(normalization_file):
+		norm_dict = torch.load(normalization_file)
+	else:
+		norm_dict = {'normalization_file':normalization_file}
+
+	if 'normalization_file' in norm_dict:
+
+		transformers = get_data_transforms(patch_size = 224, mean=[], std=[], resize=True, transform_platform='torch')
+
+		dataset_opts['transformers']=transformers
+		#print(dict(pos_annotation_class=pos_annotation_class, segmentation=segmentation, patch_size=patch_size, fix_names=fix_names, other_annotations=other_annotations))
+
+		dataset = DynamicImageDataset(**dataset_opts)#nc.SafeDataset(DynamicImageDataset(**dataset_opts))
+
+		if dataset_opts['classify_annotations']:
+			dataset.binarize_annotations()
+
+		dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
+
+		all_mean = torch.tensor([0.,0.,0.],dtype=torch.float)#[]
+
+		all_std = torch.tensor([0.,0.,0.],dtype=torch.float)
+
+		if torch.cuda.is_available():
+			all_mean=all_mean.cuda()
+			all_std=all_std.cuda()
+
+		with torch.no_grad():
+			for i,(X,_) in enumerate(dataloader): # x,3,224,224
+				if torch.cuda.is_available():
+					X=X.cuda()
+				all_mean += torch.mean(X, (0,2,3))
+				all_std += torch.std(X, (0,2,3))
+
+		N=i+1 # number of batches seen
+
+		all_mean /= float(N) # average of per-batch channel means
+		all_std /= float(N) # average of per-batch channel stds; approximates the dataset-wide std
+
+		all_mean = all_mean.detach().cpu().numpy().tolist()
+		all_std = all_std.detach().cpu().numpy().tolist()
+
+		torch.save(dict(mean=all_mean,std=all_std),norm_dict['normalization_file'])
+
+		norm_dict = torch.load(norm_dict['normalization_file'])
+	return norm_dict
+
+def segmentation_transform(img,mask, transformer, normalizer, alb_reduction):
+	"""Run albumentations and return an image and its segmentation mask.
+
+	Parameters
+	----------
+	img:array
+		Image as array
+	mask:array
+		Categorical pixel by pixel.
+	transformer:Compose
+		Albumentations transformation pipeline.
+	normalizer:Compose
+		Torchvision normalization applied after the albumentations transforms.
+	alb_reduction:float
+		Divisor (e.g. 255.) applied to pixel intensities before normalization.
+
+	Returns
+	-------
+	tuple arrays
+		Image and mask array.
+
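+	Examples
+	--------
+	A sketch with hypothetical img/mask arrays and albumentations transformers from get_data_transforms:
+
+	>>> img_t, mask_t = segmentation_transform(img, mask, transformers['train'],
+	... 	transformers['normalize'], 255.)
+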
+	"""
+	res=transformer(force_apply=True, image=img, mask=mask) # keyword form; newer albumentations versions reject positional arguments
+	#res_mask_shape = res['mask'].size()
+	return normalizer(torch.tensor(np.transpose(res['image']/alb_reduction,axes=(2,0,1)),dtype=torch.float)).float(), torch.tensor(res['mask']).long()#.view(res_mask_shape[0],res_mask_shape[1],res_mask_shape[2])
+
+class DilationJitter:
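+	"""Randomly dilate selected mask classes to jitter segmentation boundaries during training.
+
+	Parameters
+	----------
+	dilation_jitter:dict
+		Maps a mask class label to dict(mean=..., std=...); the number of dilation
+		iterations is drawn from that normal distribution (minimum of one).
+	segmentation:bool
+		Jitter is only applied for segmentation tasks.
+	train_set:bool
+		Jitter is only applied on the training set.
+
+	Examples
+	--------
+	A sketch that dilates class 2 by roughly three pixels:
+
+	>>> jitter = DilationJitter({2: dict(mean=3, std=1)}, segmentation=True, train_set=True)
+	>>> jittered_mask = jitter(mask)  # mask: 2-D integer label array
+	"""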
+	def __init__(self, dilation_jitter=dict(), segmentation=True, train_set=False):
+		if dilation_jitter and segmentation and train_set:
+			self.run_jitter=True
+			self.dilation_jitter=dilation_jitter
+			self.struct=generate_binary_structure(2,1) #structure=self.struct,
+		else:
+			self.run_jitter=False
+
+
+	def __call__(self, mask):
+		if self.run_jitter:
+			for k in self.dilation_jitter:
+				amount_jitter=int(round(max(np.random.normal(self.dilation_jitter[k]['mean'],
+														self.dilation_jitter[k]['std']),1)))
+				#print((mask==k).compute())
+				mask[binary_dilation(mask==k,structure=self.struct,iterations=amount_jitter)]=k
+
+		return mask
+
+
+class DynamicImageDataset(Dataset):
+	"""Generate image dataset that accesses images and annotations via dask.
+
+	Parameters
+	----------
+	dataset_df:dataframe
+		Dataframe listing each WSI, the set it belongs to (train/val/test), and its labels if applicable.
+	set:str
+		Whether train, test, val or pass (normalization) set.
+	patch_info_file:str
+		SQL db with positional and annotation information on each slide.
+	transformers:dict
+		Contains transformers to apply on images.
+	input_dir:str
+		Directory the images come from.
+	target_names:list/str
+		Names of initial targets, which may be modified.
+	pos_annotation_class:str
+		If selected and predicting on WSI, this class is labeled as a positive from the WSI, while the other classes are not.
+	other_annotations:list
+		Other annotations to consider from patch info db.
+	segmentation:bool
+		Conducting segmentation task?
+	patch_size:int
+		Patch size.
+	fix_names:bool
+		Whether to standardize the slide IDs in dataset_df (via fix_name).
+	target_segmentation_class:list
+		Sample images only from these classes (now usable for classification as well); can be specified multiple times together with the two options below.
+	target_threshold:list
+		Only sample patches where the class occurs above this threshold.
+	oversampling_factor:list
+		Oversample the matching patches by this factor.
+	n_segmentation_classes:int
+		Number of classes to segment.
+	gdl:bool
+		Using generalized dice loss?
+	mt_bce:bool
+		For multi-target prediction tasks.
+	classify_annotations:bool
+		For classifying annotations.
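+	dilation_jitter:dict
+		Optional per-class mask boundary jitter applied to training masks (see DilationJitter).
+	modify_patches:bool
+		Passed through to modify_patch_info.
+
+	Examples
+	--------
+	A construction sketch; the dataframe, DB path, and directories are hypothetical:
+
+	>>> transformers = get_data_transforms(patch_size=224, transform_platform='albumentations')
+	>>> dataset = DynamicImageDataset(dataset_df, 'train', 'patch_info.db', transformers,
+	... 	'inputs/', ['target'], '', segmentation=True)
+	>>> image, mask = dataset[0]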
+
+	"""
+	# when building transformers, need a resize patch size to make patches 224 by 224
+	#@pysnooper.snoop('init_data.log')
+	def __init__(self,dataset_df, set, patch_info_file, transformers, input_dir, target_names, pos_annotation_class, other_annotations=[], segmentation=False, patch_size=224, fix_names=True, target_segmentation_class=-1, target_threshold=0., oversampling_factor=1., n_segmentation_classes=4, gdl=False, mt_bce=False, classify_annotations=False, dilation_jitter=dict(), modify_patches=True):
+
+		#print('check',classify_annotations)
+		reduce_alb=True
+		self.patch_size=patch_size
+		self.input_dir = input_dir
+		self.alb_reduction=255. if reduce_alb else 1.
+		self.transformer=transformers[set]
+		original_set = copy.deepcopy(set)
+		if set=='pass':
+			set='train'
+		self.targets = target_names
+		self.mt_bce=mt_bce
+		self.set = set
+		self.segmentation = segmentation
+		self.alb_normalizer=None
+		if 'normalize' in transformers:
+			self.alb_normalizer = transformers['normalize']
+		if len(self.targets)==1:
+			self.targets = self.targets[0]
+		if original_set == 'pass':
+			self.transform_fn = lambda x,y: (self.transformer(x), torch.tensor(1.,dtype=torch.float))
+		else:
+			if self.segmentation:
+				self.transform_fn = lambda x,y: segmentation_transform(x,y, self.transformer, self.alb_normalizer, self.alb_reduction)
+			else:
+				if 'p' in dir(self.transformer): # albumentations Compose exposes a `p` attribute; torchvision Compose does not
+					self.transform_fn = lambda x,y: (self.alb_normalizer(torch.tensor(np.transpose(self.transformer(force_apply=True, image=x)['image']/self.alb_reduction,axes=(2,0,1)),dtype=torch.float)), torch.from_numpy(y).float())
+				else:
+					self.transform_fn = lambda x,y: (self.transformer(x), torch.from_numpy(y).float())
+		self.image_set = dataset_df[dataset_df['set']==set]
+		if self.segmentation:
+			self.targets='target'
+			self.image_set[self.targets] = 1.
+		if not self.segmentation and fix_names:
+			self.image_set.loc[:,'ID'] = self.image_set['ID'].map(fix_name)
+		self.slide_info = pd.DataFrame(self.image_set.set_index('ID').loc[:,self.targets])
+		if self.mt_bce and not self.segmentation:
+			if pos_annotation_class:
+				self.targets = [pos_annotation_class]+list(other_annotations)
+			else:
+				self.targets = None
+		print(self.targets)
+		IDs = self.slide_info.index.tolist()
+		pi_dict=dict(input_info_db=patch_info_file,
+					slide_labels=self.slide_info,
+					pos_annotation_class=pos_annotation_class,
+					patch_size=patch_size,
+					segmentation=self.segmentation,
+					other_annotations=other_annotations,
+					target_segmentation_class=target_segmentation_class,
+					target_threshold=target_threshold,
+					classify_annotations=classify_annotations,
+					modify_patches=modify_patches)
+		self.patch_info = modify_patch_info(**pi_dict)
+
+		if self.segmentation and original_set!='pass':
+			#IDs = self.patch_info['ID'].unique()
+			self.segmentation_maps = {slide:npy2da(join(input_dir,'{}_mask.npy'.format(slide))) for slide in IDs}
+		self.slides = {slide:load_preprocessed_img(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs}
+		#print(self.slide_info)
+		if original_set =='pass':
+			self.segmentation=False
+		#print(self.patch_info[self.targets].unique())
+		if oversampling_factor > 1:
+			self.patch_info = pd.concat([self.patch_info]*int(oversampling_factor),axis=0).reset_index(drop=True)
+		elif oversampling_factor < 1:
+			self.patch_info = self.patch_info.sample(frac=oversampling_factor).reset_index(drop=True)
+		self.length = self.patch_info.shape[0]
+		self.n_segmentation_classes = n_segmentation_classes
+		self.gdl=gdl if self.segmentation else False
+		self.binarized=False
+		self.classify_annotations=classify_annotations
+		print(self.targets)
+		self.dilation_jitter=DilationJitter(dilation_jitter,self.segmentation,(original_set=='train'))
+		if not self.targets:
+			self.targets = [pos_annotation_class]+list(other_annotations)
+
+	def concat(self, other_dataset):
+		"""Concatenate another dataset onto this one, updating internal attributes in place.
+
+		Parameters
+		----------
+		other_dataset:DynamicImageDataset
+			Other image dataset.
+
+		"""
+		self.patch_info = pd.concat([self.patch_info, other_dataset.patch_info],axis=0).reset_index(drop=True)
+		self.length = self.patch_info.shape[0]
+		self.slides.update(other_dataset.slides) # merge slide handles so patches from the other dataset remain loadable
+		if self.segmentation:
+			self.segmentation_maps.update(other_dataset.segmentation_maps)
+			#print(self.segmentation_maps.keys())
+
+	def retain_ID(self, ID):
+		"""Reduce the sample set to just images from one ID.
+
+		Parameters
+		----------
+		ID:str
+			Basename/ID to predict on.
+
+		Returns
+		-------
+		self
+
+		"""
+		self.patch_info=self.patch_info.loc[self.patch_info['ID']==ID]
+		self.length = self.patch_info.shape[0]
+		if self.segmentation: # segmentation maps exist only for segmentation datasets
+			self.segmentation_maps={ID:self.segmentation_maps[ID]}
+		return self
+
+	def split_by_ID(self):
+		"""Generator that splits the dataset up by ID (similar to a groupby), yielding (ID, dataset) pairs via retain_ID.
+
+		Returns
+		-------
+		generator
+			(ID, DynamicImageDataset) pairs.
+
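+		Examples
+		--------
+		A sketch; predict_slide is hypothetical:
+
+		>>> for ID, slide_dataset in dataset.split_by_ID():
+		... 	predictions = predict_slide(slide_dataset)
+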
+		"""
+		for ID in self.patch_info['ID'].unique():
+			new_dataset = copy.deepcopy(self)
+			yield ID, new_dataset.retain_ID(ID)
+
+	def select_IDs(self, IDs):
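+		"""Like split_by_ID, but only yields (ID, dataset) pairs for the requested IDs."""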
+		for ID in IDs:
+			if ID in self.patch_info['ID'].unique():
+				new_dataset = copy.deepcopy(self)
+				yield ID, new_dataset.retain_ID(ID)
+
+
+	def get_class_weights(self, i=0):
+		"""Weight the loss function with weights inversely proportional to class frequency.
+
+		Parameters
+		----------
+		i:int
+			If multi-target, class used for weighting.
+
+		Returns
+		-------
+		array
+			Class weights.
+
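+		Examples
+		--------
+		A sketch; feed the weights to a loss function:
+
+		>>> weights = torch.tensor(dataset.get_class_weights(), dtype=torch.float)
+		>>> criterion = torch.nn.CrossEntropyLoss(weight=weights)
+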
+		"""
+		if self.segmentation:
+			label_counts=self.patch_info[list(map(str,list(range(self.n_segmentation_classes))))].sum(axis=0).values
+			freq = label_counts/sum(label_counts)
+			weights=1./(freq)
+		elif self.mt_bce:
+			weights=1./(self.patch_info.loc[:,self.targets].sum(axis=0).values)
+			weights=weights/sum(weights)
+		else:
+			if self.binarized and len(self.targets)>1:
+				y=np.argmax(self.patch_info.loc[:,self.targets].values,axis=1)
+			elif isinstance(self.targets, str):
+				y=self.patch_info.loc[:,self.targets]
+			else:
+				y=self.patch_info.loc[:,self.targets[i]]
+			y=y.values.astype(int).flatten()
+			weights=compute_class_weight(class_weight='balanced',classes=np.unique(y),y=y)
+		return weights
+
+	def binarize_annotations(self, binarizer=None, num_targets=1, binary_threshold=0.):
+		"""Label-binarize annotation columns, or threshold them when classifying slide annotations.
+
+		Parameters
+		----------
+		binarizer:LabelBinarizer
+			Pre-fit binarizer to reuse; if None, a new one is fit on the annotations.
+		num_targets:int
+			Number of desired targets to predict on.
+		binary_threshold:float
+			Minimum annotation amount in a patch for it to be labeled positive.
+
+		Returns
+		-------
+		binarizer
+
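+		Examples
+		--------
+		A sketch: fit the binarizer on the training set, then reuse it on validation:
+
+		>>> binarizer = train_dataset.binarize_annotations(num_targets=2)
+		>>> val_dataset.binarize_annotations(binarizer=binarizer, num_targets=2)
+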
+		"""
+
+		annotations = self.patch_info['annotation']
+		annots=[annot for annot in list(self.patch_info.iloc[:,6:]) if annot !='area']
+		if not self.mt_bce and num_targets > 1:
+			if binarizer is None:
+				self.binarizer = LabelBinarizer().fit(annotations)
+			else:
+				self.binarizer = copy.deepcopy(binarizer)
+			self.targets = self.binarizer.classes_
+			annotation_labels = pd.DataFrame(self.binarizer.transform(annotations),index=self.patch_info.index,columns=self.targets).astype(float)
+			for col in list(annotation_labels):
+				if col in list(self.patch_info):
+					self.patch_info.loc[:,col]=annotation_labels[col].values
+				else:
+					self.patch_info[col]=annotation_labels[col].values
+		else:
+			self.binarizer=None
+			self.targets=annots
+			if num_targets == 1:
+				self.targets = [self.targets[-1]]
+			if binary_threshold>0.:
+				self.patch_info.loc[:,self.targets]=(self.patch_info[self.targets]>=binary_threshold).values.astype(np.float32)
+			print(self.targets)
+			#self.patch_info = pd.concat([self.patch_info,annotation_labels],axis=1)
+		self.binarized=True
+		return self.binarizer
+
+	def subsample(self, p):
+		"""Sample subset of dataset.
+
+		Parameters
+		----------
+		p:float
+			Fraction to subsample.
+
+		"""
+		np.random.seed(42)
+		self.patch_info = self.patch_info.sample(frac=p)
+		self.length = self.patch_info.shape[0]
+
+	def update_dataset(self, input_dir, new_db, prediction_basename=[]):
+		"""Experimental. Only use for segmentation for now."""
+		self.input_dir=input_dir
+		self.patch_info=load_sql_df(new_db, self.patch_size)
+		IDs = self.patch_info['ID'].unique()
+		self.slides = {slide:load_preprocessed_img(join(self.input_dir,'{}.zarr'.format(slide))) for slide in IDs}
+		if self.segmentation:
+			if prediction_basename:
+				self.segmentation_maps = {slide:npy2da(join(self.input_dir,'{}_mask.npy'.format(slide))) for slide in IDs if slide in prediction_basename}
+			else:
+				self.segmentation_maps = {slide:npy2da(join(self.input_dir,'{}_mask.npy'.format(slide))) for slide in IDs}
+		self.length = self.patch_info.shape[0]
+
+	#@pysnooper.snoop("getitem.log")
+	def __getitem__(self, i):
+		patch_info = self.patch_info.iloc[i]
+		ID = patch_info['ID']
+		xs = patch_info['x']
+		ys = patch_info['y']
+		patch_size = patch_info['patch_size']
+		entire_image = pd.isnull(xs) # NaN coordinates flag a whole-slide image; xs==np.nan is always False, so pd.isnull is required
+		use_long=False
+		if not self.segmentation:
+			y = patch_info.loc[list(self.targets) if not isinstance(self.targets,str) else self.targets]
+			if isinstance(y,pd.Series):
+				y=y.values.astype(float)
+				if self.binarized and not self.mt_bce and len(y)>1:
+					y=np.array(y.argmax())
+					use_long=True
+			y=np.array(y)
+			if not y.shape:
+				y=y.reshape(1)
+		if self.segmentation:
+			arr=self.segmentation_maps[ID]
+			if not entire_image:
+				arr=arr[xs:xs+patch_size,ys:ys+patch_size]
+			arr=self.dilation_jitter(arr)
+		y=(y if not self.segmentation else np.array(arr))
+		#print(y)
+		arr=self.slides[ID]
+		if not entire_image:
+			arr=arr[xs:xs+patch_size,ys:ys+patch_size,:3]
+		image, y = self.transform_fn(arr.compute().astype(np.uint8), y)#.unsqueeze(0) # transpose .transpose([1,0,2])
+		if not self.segmentation and not self.mt_bce and self.classify_annotations and use_long:
+			y=y.long()
+		#image_size=image.size()
+		if self.gdl:
+			y=class2one_hot(y, self.n_segmentation_classes)
+		#	y=one_hot2dist(y)
+		return image, y
+
+	def __len__(self):
+		return self.length
+
+class NPYDataset(Dataset):
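+	"""Serve patches from a single preprocessed .npy image, using the patch-info SQL DB entries for that slide.
+
+	Parameters
+	----------
+	patch_info:str
+		Path to the patch-info SQL DB.
+	patch_size:int
+		Patch size to load from the DB.
+	npy_file:str
+		Preprocessed .npy image; its basename is taken as the slide ID.
+	transform:
+		Transform applied to each extracted patch.
+	mmap:bool
+		Memory-map the .npy file rather than loading it fully into memory.
+	"""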
+	def __init__(self, patch_info, patch_size, npy_file, transform, mmap=False):
+		self.ID=os.path.basename(npy_file).split('.')[0]
+		patch_info=load_sql_df(patch_info,patch_size)
+		self.patch_info=patch_info.loc[patch_info["ID"]==self.ID].reset_index()
+		self.X=np.load(npy_file,mmap_mode=(None if not mmap else 'r+'))
+		self.transform=transform
+
+	def __getitem__(self,i):
+		x,y,patch_size=self.patch_info.loc[i,["x","y","patch_size"]]
+		return self.transform(self.X[x:x+patch_size,y:y+patch_size])
+
+	def __len__(self):
+		return self.patch_info.shape[0]
+
+	def embed(self,model,batch_size,out_dir):
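+		"""Embed every patch with model, save embeddings and patch info to <out_dir>/<ID>.pkl, then exit."""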
+		Z=[]
+		dataloader=DataLoader(self,batch_size=batch_size,shuffle=False)
+		n_batches=len(self)//batch_size
+		with torch.no_grad():
+			for i,X in enumerate(dataloader):
+				if torch.cuda.is_available():
+					X=X.cuda()
+				z=model(X).detach().cpu().numpy()
+				Z.append(z)
+				print(f"Processed batch {i}/{n_batches}")
+		Z=np.vstack(Z)
+		torch.save(dict(embeddings=Z,patch_info=self.patch_info),os.path.join(out_dir,f"{self.ID}.pkl"))
+		print("Embeddings saved")
+		quit() # exits the interpreter after saving; embed is run as a standalone, one-shot step