--- a
+++ b/metrics.py
@@ -0,0 +1,257 @@
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from sklearn.utils.extmath import cartesian
+from hausdorff import hausdorff_distance
+
+__all__ = ['Dice loss', 'Cross entropy', 'Focal loss', 'Dice Iou Cross entropy', 'Binary dice loss']
+
+
+class IOU(nn.Module):
+    '''
+    Calculate Intersection over Union (IoU) for semantic segmentation.
+    
+    Args:
+        logits (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
+        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width, (depth))
+        num_classes (int): Number of classes
+
+    Returns:
+        tensor: Mean Intersection over Union (IoU) for the batch.
+        list: List of IOU score for each class
+    '''
+    def __init__(self, num_classes, ignore_index=[0]):
+        super(IOU, self).__init__()
+        self.num_classes = num_classes
+        self.ignore_index = ignore_index
+        
+    def forward(self, logits, target):
+        pred = logits.argmax(dim=1)        
+        target = target.argmax(dim=1)       
+        ious = []
+        for cls in range(self.num_classes):
+            if cls in self.ignore_index: continue
+            pred_mask = (pred == cls)
+            target_mask = (target == cls)
+                            
+            intersection = (pred_mask & target_mask).sum().float()
+            union = (pred_mask | target_mask).sum().float()
+            
+            if union == 0: iou = 1.0 
+            else: iou = (intersection / union).item()
+            ious.append(iou)
+        
+        mean_iou = sum(ious) / (self.num_classes - len(self.ignore_index))
+        return torch.tensor(mean_iou), ious
+
+    
+class BinaryDice(nn.Module):
+    '''
+    Calculate Binary Dice score and Dice loss for binary segmentation or each class in Multiclass segmentation
+    
+    Args:
+        logits (torch.Tensor): Predicted tensor of shape (batch_size, height, width, (depth))
+        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width. (depth))
+        
+    Returns:
+        tensor: Dice score
+        tensor: Dice loss
+    '''
+    def __init__(self, smooth=1e-5, p=2):
+        super(BinaryDice, self).__init__()
+        self.smooth = smooth
+        self.p = p
+
+    def forward(self, logits, target):
+        assert logits.shape[0] == target.shape[0], "logits & Target batch size don't match"
+        smooth = 1e-5
+        intersect = torch.sum(logits * target)        
+        y_sum = torch.sum(target * target)
+        z_sum = torch.sum(logits * logits)
+        dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
+        loss = 1 - dice
+        return dice, loss
+        
+
+class Dice(nn.Module):
+    '''
+    Calculate Dice score and Dice loss for multiclass semantic segmentation
+    
+    Args:
+        output (torch.Tensor): Predicted tensor of shape (batch_size, num_classes, height, width, (depth))
+        target (torch.Tensor): Ground truth tensor of shape (batch_size, height, width, (depth))
+        num_classes (int): Number of classes 
+        
+    Returns:
+        tensor: Mean dice score over classes
+        tensor: Mean dice loss over classes
+        list: dice score for each classes
+        listL dice loss for each classes
+    '''
+    def __init__(self, num_classes, weight=None, softmax=True, ignore_index=[0]):
+        super(Dice, self).__init__()
+        self.num_classes = num_classes
+        self.weight = weight
+        self.softmax = softmax
+        self.ignore_index = ignore_index
+        self.binary_dice = BinaryDice()
+
+    def forward(self, logits, target):
+        assert logits.shape == target.shape, 'logits & Target shape do not match'
+        if self.softmax: logits = F.softmax(logits, dim=1)
+        
+        DICE, LOSS = 0.0, 0.0
+        CLS_DICE, CLS_LOSS = [], []
+        for clx in range(target.shape[1]):
+            if clx in self.ignore_index: continue
+            dice, loss = self.binary_dice(logits[:, clx], target[:, clx])
+            CLS_DICE.append(dice.item())
+            CLS_LOSS.append(loss.item())
+            if self.weight is not None: dice *= self.weights[clx]
+            DICE += dice
+            LOSS += loss
+
+        num_valid_classes = self.num_classes - len(self.ignore_index)
+        return DICE / num_valid_classes, LOSS / num_valid_classes, CLS_DICE, CLS_LOSS
+    
+    
+class WeightedHausdorffDistance(nn.Module):
+    def __init__(self, height, width, p=-9, return_2_terms=False, device=torch.device('cuda')):
+        '''
+        height         (int):  image height
+        width          (int):  image width
+        return_2_terms (bool): Whether to return the 2 terms
+                               of the WHD instead of their sum.
+        '''
+        super().__init__()
+        self.height, self.width = height, width
+        self.size = torch.tensor([height, width], dtype=torch.get_default_dtype(), device=device)
+        self.max_dist = math.sqrt(height**2 + width**2)
+        self.n_pixels = height * width
+        self.all_img_locations = torch.from_numpy(cartesian([np.arange(height), np.arange(width)]))
+        self.all_img_locations = self.all_img_locations.to(device=device, dtype=torch.get_default_dtype())
+        self.return_2_terms = return_2_terms
+        self.p = p
+        
+    def _assert_no_grad(self, variables):
+        for var in variables:
+            assert not var.requires_grad, \
+                "nn criterions don't compute the gradient w.r.t. targets - please " \
+                "mark these variables as volatile or not requiring gradients"
+                
+    def cdist(self, x, y):
+        '''
+        Compute distance between each pair of the two collections of inputs.
+        x: Nxd Tensor
+        y: Mxd Tensor
+        return: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:]
+                i.e. dist[i,j] = || x[i,:] - y[j,:] ||
+        '''
+        difs = x.unsqueeze(1) - y.unsqueeze(0)
+        dists = torch.sum(difs**2, -1).sqrt()
+        return dists
+    
+    def generalize_mean(self, tensor, dim, p=-9, keepdim=False):
+        assert p < 0
+        res= torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p)
+        return res
+        
+    def forward(self, prob_map, gt, orig_sizes):
+        '''
+        prob_map: (B x H x W) Tensor of the probability map of the estimation.
+                              B is batch size, H is height and W is width.
+                              Values must be between 0 and 1.
+                              
+        gt: List of Tensors of the Ground Truth points.
+            Must be of size B as in prob_map.
+            Each element in the list must be a 2D Tensor,
+            where each row is the (y, x), i.e, (row, col) of a GT point.
+        
+        orig_sizes: Bx2 Tensor containing the size
+                    of the original images.
+                    B is batch size.
+                    The size must be in (height, width) format. 
+                    
+        return: Single-scalar Tensor with the Weighted Hausdorff Distance.
+                If self.return_2_terms=True, then return a tuple containing
+                the two terms of the Weighted Hausdorff Distance.
+        '''
+        
+        self._assert_no_grad(gt)
+        assert prob_map.dim() == 3, 'The probability map must be (B x H x W)'
+        assert prob_map.size()[1:3] == (self.height, self.width), \
+            'You must configure the WeightedHausdorffDistance with the height and width of the ' \
+            'probability map that you are using, got a probability map of size %s'\
+            % str(prob_map.size())
+            
+        batch_size = prob_map.shape[0]
+        assert batch_size == len(gt)
+        
+        terms_1 = []
+        terms_2 = []
+        for b in range(batch_size):
+            
+            # One by one
+            prob_map_b = prob_map[b, :, :]
+            gt_b = gt[b]
+            orig_size_b = orig_sizes[b, :]
+            norm_factor = (orig_size_b / self.size).unsqueeze(0)
+            n_gt_pts = gt_b.size()[0]
+            
+            # Corner case: no GT points
+            if gt_b.ndimension() == 1 and (gt_b < 0).all().item() == 0:
+                terms_1.append(torch.tensor([0], 
+                                            dtype=torch.get_default_dtype()))
+                terms_2.append(torch.tensor([self.max_dist],
+                                            dtype=torch.get_default_dtype())) 
+                continue
+            
+            # Pairwise distances between all possible locations and the GTed locations
+            n_gt_pts = gt_b.size()[0]
+            normalized_x = norm_factor.repeat(self.n_pixels, 1) * self.all_img_locations
+            normalized_y = norm_factor.repeat(len(gt_b), 1) * gt_b
+            d_matrix = self.cdist(normalized_x, normalized_y)
+            
+            # Reshape probability map as a long column vector
+            # and prepare it for mulitplication
+            p = prob_map_b.view(prob_map_b.nelement())
+            n_est_pts = p.sum()
+            p_replicated = p.view(-1, 1).repeat(1, n_gt_pts)
+            
+            # Weighted Hausdorff Distance
+            term_1 = (1 / (n_est_pts + 1e-6)) * torch.sum(p * torch.min(d_matrix, 1)[0])
+            weighted_d_matrix = (1 - p_replicated)*self.max_dist + p_replicated*d_matrix
+            minn = self.generalize_mean(weighted_d_matrix,
+                                  p=self.p,
+                                  dim=0, keepdim=False)
+            term_2 = torch.mean(minn)
+
+            terms_1.append(term_1)
+            terms_2.append(term_2)
+            
+        terms_1 = torch.stack(terms_1)
+        terms_2 = torch.stack(terms_2)
+        
+        if self.return_2_terms: res = terms_1.mean(), terms_2.means()
+        else: res = terms_1.mean() + terms_2.mean()
+        return res
+    
+
+class HD(nn.Module):
+    def __init__(self):
+        super().__init__()
+        
+    def forward(self, logits, target):
+        _,logits = torch.max(logits, dim=1)
+        _,target = torch.max(target, dim=1)
+        
+        logits = logits.detach().cpu().numpy()
+        target = target.detach().cpu().numpy()
+        
+        hd = 0
+        for index in range(logits.shape[0]):
+            hd += hausdorff_distance(logits[index], target[index], distance='euclidean')
+        
+        return hd / logits.shape[0]
\ No newline at end of file