--- a
+++ b/utils/mm_regularization.py
@@ -0,0 +1,214 @@
+"""
+An implementation of the paper "Removing Bias in Multi-modal Classifiers: Regularization by Maximizing
+Functional Entropies" (NeurIPS 2020).
+"""
+
+import torch
+
+
+class Perturbation:
+    """
+    Class in charge of the perturbation techniques.
+    """
+    @classmethod
+    def _add_noise_to_tensor(cls, tens: torch.Tensor, over_dim: int = 0) -> torch.Tensor:
+        """
+        Adds Gaussian noise with standard deviation tens.std(dim=over_dim) to a tensor.
+        :param tens: input tensor
+        :param over_dim: dimension over which to compute the std: 0 for a per-feature std over the batch,
+         1 for a per-sample std.
+        :return: noisy tensor with the same shape as the input
+        """
+
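+        # tens.std(dim=over_dim) broadcasts against tens, so with over_dim=0 each feature
+        # gets a noise scale matching its own spread over the batch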
+        return tens + torch.randn_like(tens) * tens.std(dim=over_dim)
+        # return tens + torch.randn_like(tens)
+
+    @classmethod
+    def perturb_tensor(cls, tens: torch.Tensor, n_samples: int, perturbation: bool = True) -> torch.Tensor:
+        """
+        Flattens the tensor, expands it, perturbs it and reconstructs the original shape.
+        Note, this function assumes that the batch is the first dimension.
+        :param tens: input tensor
+        :param n_samples: number of perturbed copies to create per example
+        :param perturbation: if False, the tensor is only duplicated without adding noise
+        :return: tensor with the original shape, except that the batch dimension is multiplied by n_samples
+        """
+        tens_dim = list(tens.shape)
+
+        # flatten all non-batch dimensions: [batch, features]
+        tens = tens.view(tens.shape[0], -1)
+        # repeat each row n_samples times along the feature axis: [batch, features * n_samples]
+        tens = tens.repeat(1, n_samples)
+
+        # unfold the copies into their own rows: [batch * n_samples, features]
+        tens = tens.view(tens.shape[0] * n_samples, -1)
+
+        if perturbation:
+            tens = cls._add_noise_to_tensor(tens)
+
+        # restore the original trailing dimensions, with the batch dimension expanded by n_samples
+        tens_dim[0] *= n_samples
+
+        tens = tens.view(*tens_dim)
+        tens.requires_grad_()
+
+        return tens
+
+    @classmethod
+    def get_expanded_logits(cls, logits: torch.Tensor, n_samples: int, logits_flg: bool = True) -> torch.Tensor:
+        """
+        Applies softmax (if needed) and then expands each row n_samples times.
+        :param logits_flg: whether the input holds raw logits (True) or softmax probabilities (False)
+        :param logits: tensor holding the logits output by the model
+        :param n_samples: number of times to duplicate each row
+        :return: expanded probabilities of shape [batch * n_samples, num_classes]
+        """
+        if logits_flg:
+            logits = torch.nn.functional.softmax(logits, dim=1)
+        expanded_logits = logits.repeat(1, n_samples)
+
+        return expanded_logits.view(expanded_logits.shape[0] * n_samples, -1)
+
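+# Shape sketch, assuming a hypothetical input x of shape [batch, d] and n_samples = k:
+# Perturbation.perturb_tensor(x, k) returns a [batch * k, d] tensor whose rows
+# i*k .. (i+1)*k - 1 are noisy copies of sample i, and Perturbation.get_expanded_logits(probs, k)
+# repeats the class probabilities in the same row order, so the two stay aligned row by row.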
+
+class Regularization(object):
+    """
+    Class in charge of the regularization techniques.
+    """
+    @classmethod
+    def _get_variance(cls, loss: torch.Tensor) -> torch.Tensor:
+        """
+        Computes the variance over the sample dimension (dim=1) for each element of the batch
+        :param loss: tensor of shape [batch, number of evaluation samples]
+        :return: per-example variance of the loss values
+        """
+
+        return torch.var(loss, dim=1)
+
+    @classmethod
+    def _get_differential_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
+        """
+        Computes the differential entropy: -sum(f * log f)
+        :param loss: tensor holding (positive) loss values
+        :return: a tensor holding the differential entropy for a batch
+        """
+
+        return -1 * torch.sum(loss * loss.log())
+
+    @classmethod
+    def _get_functional_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
+        """
+        Computes the functional entropy: E[f log f] - E[f] log E[f], after normalizing f to sum to 1 per row
+        :param loss: tensor of shape [batch, number of evaluation samples]
+        :return: a tensor holding the functional entropy for a batch
+        """
+        loss = torch.nn.functional.normalize(loss, p=1, dim=1)
+        loss = torch.mean(loss * loss.log()) - (torch.mean(loss) * torch.mean(loss).log())
+
+        return loss
+
+    @classmethod
+    def get_batch_statistics(cls, loss: torch.Tensor, n_samples: int, estimation: str = 'ent') -> torch.Tensor:
+        """
+        Computes the chosen information estimate over a batch of per-sample losses.
+        Assumes the losses keep the row ordering produced by Perturbation.perturb_tensor,
+        so that every n_samples consecutive values belong to the same example.
+        :param n_samples: number of evaluation samples per example
+        :param loss: flattened tensor of per-sample losses, of length batch * n_samples
+        :param estimation: 'var', 'ent' (functional entropy) or 'dif_ent' (differential entropy)
+        :return: mean of the information estimate over the batch
+        """
+        loss = loss.reshape(-1, n_samples)
+
+        if estimation == 'var':
+            batch_statistics = cls._get_variance(loss)
+            batch_statistics = torch.abs(batch_statistics)
+        elif estimation == 'ent':
+            batch_statistics = cls._get_functional_entropy(loss)
+        elif estimation == 'dif_ent':
+            batch_statistics = cls._get_differential_entropy(loss)
+        else:
+            raise NotImplementedError(f'{estimation} is an unknown estimation method, please use "var", "ent" or "dif_ent".')
+
+        return torch.mean(batch_statistics)
+
+    @classmethod
+    def get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
+        """
+        Computes the mean squared L2 norm of the per-example gradients.
+        :param loss: per-example loss values, used only when estimation == 'ent'
+        :param estimation: 'ent' divides each squared gradient norm by its loss value
+        :param grad: tensor holding the gradient batch
+        :return: approximation of the required expectation
+        """
+        batch_grad_norm = torch.norm(grad, p=2, dim=1)
+        batch_grad_norm = torch.pow(batch_grad_norm, 2)
+
+        if estimation == 'ent':
+            batch_grad_norm = batch_grad_norm / loss
+
+        return torch.mean(batch_grad_norm)
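+
+    # Note: when estimation == 'ent' the squared gradient norm is divided by the per-example loss;
+    # this appears to correspond to the functional-Fisher-information term |grad f|^2 / f that bounds
+    # the functional entropy via the log-Sobolev inequality in the paper.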
+
+    @classmethod
+    def _get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
+        """
+        Same as get_batch_norm, but returns the per-example values without averaging over the batch.
+        :param loss: per-example loss values, used only when estimation == 'ent'
+        :param estimation: 'ent' divides each squared gradient norm by its loss value
+        :param grad: tensor holding the gradient batch
+        :return: per-example squared gradient norms
+        """
+        batch_grad_norm = torch.norm(grad, p=2, dim=1)
+        batch_grad_norm = torch.pow(batch_grad_norm, 2)
+
+        if estimation == 'ent':
+            batch_grad_norm = batch_grad_norm / loss
+
+        return batch_grad_norm
+
+    @classmethod
+    def _get_max_ent(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
+        """
+        Calculates the norm of the element-wise reciprocal of the information scores
+        :param inf_scores: tensor holding batch information scores
+        :param norm: which norm to use
+        :return: regularization term that decreases as the information scores grow
+        """
+        return torch.norm(torch.div(1, inf_scores), p=norm)
+
+    @classmethod
+    def _get_max_ent_minus(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
+        """
+        Calculates -1 times the norm of the information scores, offset by 0.1
+        :param inf_scores: tensor holding batch information scores
+        :param norm: which norm to use
+        :return: negated information norm plus a 0.1 offset
+        """
+        return -1 * torch.norm(inf_scores, p=norm) + 0.1
+
+    @classmethod
+    def get_regularization_term(cls, inf_scores: torch.Tensor, norm: float = 2.0,
+                                optim_method: str = 'max_ent') -> torch.Tensor:
+        """
+        Computes the regularization term given a batch of information scores
+        :param inf_scores: tensor holding a batch of information scores
+        :param norm: defines which norm to use (1 or 2)
+        :param optim_method: defines the optimization method (possible methods: "min_ent", "max_ent",
+         "max_ent_minus")
+        :return: scalar regularization term
+        """
+
+        if optim_method == 'max_ent':
+            return cls._get_max_ent(inf_scores, norm)
+        elif optim_method == 'min_ent':
+            return torch.norm(inf_scores, p=norm)
+        elif optim_method == 'max_ent_minus':
+            return cls._get_max_ent_minus(inf_scores, norm)
+
+        raise NotImplementedError(f'"{optim_method}" is unknown')
+
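+# A minimal usage sketch, assuming a hypothetical two-modality `model(text, image)` with `labels`; the split
+# between the gradient-based and the statistics-based estimate follows the `grad` flag of RegParameters
+# (defined below) and is an assumption, not part of this module:
+#
+#     reg = RegParameters()
+#     text_pert = Perturbation.perturb_tensor(text, reg.n_samples)
+#     image_pert = Perturbation.perturb_tensor(image, reg.n_samples)
+#     logits = model(text_pert, image_pert)
+#     losses = torch.nn.functional.cross_entropy(
+#         logits, labels.repeat_interleave(reg.n_samples), reduction='none')
+#
+#     inf_scores = []
+#     for modality in (text_pert, image_pert):
+#         if reg.grad:
+#             grads = torch.autograd.grad(losses.sum(), modality, retain_graph=True)[0]
+#             inf_scores.append(Regularization.get_batch_norm(
+#                 grads.view(grads.shape[0], -1), losses, reg.estimation))
+#         else:
+#             inf_scores.append(Regularization.get_batch_statistics(
+#                 losses, reg.n_samples, reg.estimation))
+#
+#     reg_term = Regularization.get_regularization_term(
+#         torch.stack(inf_scores), reg.norm, reg.optim_method)
+#     total_loss = losses.mean() + reg.lambda_ * reg_term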
+
+class RegParameters(object):
+    """
+    This class controls all the regularization-related properties
+    """
+    def __init__(self, lambda_: float = 1e-10, norm: float = 2.0, estimation: str = 'ent',
+                 optim_method: str = 'max_ent', n_samples: int = 10, grad: bool = True):
+        self.lambda_ = lambda_
+        self.norm = norm
+        self.estimation = estimation
+        self.optim_method = optim_method
+        self.n_samples = n_samples
+        self.grad = grad
\ No newline at end of file