--- a
+++ b/utils/mm_regularization.py
@@ -0,0 +1,214 @@
+"""
+An implementation of the paper: "Removing Bias in Multi-modal Classifiers: Regularization by Maximizing Functional
+ Entropies", NeurIPS 2020.
+"""
+
+import torch
+
+
+class Perturbation:
+    """
+    Class in charge of the perturbation techniques.
+    """
+    @classmethod
+    def _add_noise_to_tensor(cls, tens: torch.Tensor, over_dim: int = 0) -> torch.Tensor:
+        """
+        Adds noise sampled from N(0, tens.std()) to a tensor.
+        :param tens: input tensor
+        :param over_dim: over which dim to compute the std. 0 for features over the batch, 1 for over a sample.
+        :return: noisy tensor of the same shape as the input
+        """
+
+        return tens + torch.randn_like(tens) * tens.std(dim=over_dim)
+        # return tens + torch.randn_like(tens)
+
+    @classmethod
+    def perturb_tensor(cls, tens: torch.Tensor, n_samples: int, perturbation: bool = True) -> torch.Tensor:
+        """
+        Flattens the tensor, duplicates it n_samples times, perturbs the copies and reconstructs the original shape.
+        Note: this function assumes that the batch is the first dimension.
+        :param tens: input tensor
+        :param n_samples: number of perturbed copies to create per sample
+        :param perturbation: False - only duplicate the tensor, without adding noise
+        :return: tensor of shape [batch * n_samples, *remaining original dimensions], with requires_grad enabled
+        """
+        tens_dim = list(tens.shape)
+
+        tens = tens.view(tens.shape[0], -1)
+        tens = tens.repeat(1, n_samples)
+
+        tens = tens.view(tens.shape[0] * n_samples, -1)
+
+        if perturbation:
+            tens = cls._add_noise_to_tensor(tens)
+
+        tens_dim[0] *= n_samples
+
+        tens = tens.view(*tens_dim)
+        tens.requires_grad_()
+
+        return tens
+
+    @classmethod
+    def get_expanded_logits(cls, logits: torch.Tensor, n_samples: int, logits_flg: bool = True) -> torch.Tensor:
+        """
+        Applies softmax (if needed) and then expands the logits n_samples times, matching the layout of perturb_tensor.
+        :param logits_flg: whether the input holds raw logits (True) or softmax probabilities (False)
+        :param logits: tensor holding the logits output by the model
+        :param n_samples: number of times to duplicate each row
+        :return: tensor of shape [batch * n_samples, num_classes]
+        """
+        if logits_flg:
+            logits = torch.nn.functional.softmax(logits, dim=1)
+        expanded_logits = logits.repeat(1, n_samples)
+
+        return expanded_logits.view(expanded_logits.shape[0] * n_samples, -1)
+
+
+class Regularization(object):
+    """
+    Class in charge of the regularization techniques.
+    """
+    @classmethod
+    def _get_variance(cls, loss: torch.Tensor) -> torch.Tensor:
+        """
+        Computes the variance over the perturbation samples for each row of the tensor.
+        :param loss: [batch, number of evaluation samples]
+        :return: variance for each sample in the batch
+        """
+
+        return torch.var(loss, dim=1)
+
+    @classmethod
+    def _get_differential_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
+        """
+        Computes the differential entropy: -E[f log f]
+        :param loss: [batch, number of evaluation samples]
+        :return: a scalar tensor holding the differential entropy of the batch
+        """
+
+        return -1 * torch.sum(loss * loss.log())
+
+    @classmethod
+    def _get_functional_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
+        """
+        Computes the functional entropy: E[f log f] - E[f] log E[f]
+        :param loss: [batch, number of evaluation samples]
+        :return: a scalar tensor holding the functional entropy of the batch
+        """
+        loss = torch.nn.functional.normalize(loss, p=1, dim=1)
+        loss = torch.mean(loss * loss.log()) - (torch.mean(loss) * torch.mean(loss).log())
+
+        return loss
+
+    @classmethod
+    def get_batch_statistics(cls, loss: torch.Tensor, n_samples: int, estimation: str = 'ent') -> torch.Tensor:
+        """
+        Computes the information statistic (variance or entropy) of the loss over the perturbation samples.
+        :param n_samples: number of perturbation samples per example
+        :param loss: flat tensor of per-sample loss values
+        :param estimation: 'var', 'ent' or 'dif_ent'
+        :return: mean information score of the batch
+        """
+        loss = loss.reshape(-1, n_samples)
+
+        if estimation == 'var':
+            batch_statistics = cls._get_variance(loss)
+            batch_statistics = torch.abs(batch_statistics)
+        elif estimation == 'ent':
+            batch_statistics = cls._get_functional_entropy(loss)
+        elif estimation == 'dif_ent':
+            batch_statistics = cls._get_differential_entropy(loss)
+        else:
+            raise NotImplementedError(f'{estimation} is an unknown estimation method, please use "var", "ent" or "dif_ent".')
+
+        return torch.mean(batch_statistics)
+
+    @classmethod
+    def get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
+        """
+        Approximates the expectation of the squared gradient norm over the batch.
+        :param loss: per-sample loss values, used to normalize the norms when estimation == 'ent'
+        :param estimation: 'ent' to divide each squared norm by the corresponding loss value
+        :param grad: tensor holding the batch of gradients, one row per sample
+        :return: mean of the (possibly normalized) squared gradient norms
+        """
+        batch_grad_norm = torch.norm(grad, p=2, dim=1)
+        batch_grad_norm = torch.pow(batch_grad_norm, 2)
+
+        if estimation == 'ent':
+            batch_grad_norm = batch_grad_norm / loss
+
+        return torch.mean(batch_grad_norm)
+
+    @classmethod
+    def _get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
+        """
+        Same as get_batch_norm, but returns the per-sample values instead of their mean.
+        :param loss: per-sample loss values, used to normalize the norms when estimation == 'ent'
+        :param estimation: 'ent' to divide each squared norm by the corresponding loss value
+        :param grad: tensor holding the batch of gradients, one row per sample
+        :return: per-sample (possibly normalized) squared gradient norms
+        """
+        batch_grad_norm = torch.norm(grad, p=2, dim=1)
+        batch_grad_norm = torch.pow(batch_grad_norm, 2)
+
+        if estimation == 'ent':
+            batch_grad_norm = batch_grad_norm / loss
+
+        return batch_grad_norm
+
+    @classmethod
+    def _get_max_ent(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
+        """
+        Computes the norm of the element-wise reciprocal of the information scores.
+        :param inf_scores: tensor holding a batch of information scores
+        :param norm: which norm to use
+        :return: the "max_ent" regularization term
+        """
+        return torch.norm(torch.div(1, inf_scores), p=norm)
+
+    @classmethod
+    def _get_max_ent_minus(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
+        """
+        Computes -1 times the norm of the information scores (plus a small constant offset).
+        :param inf_scores: tensor holding a batch of information scores
+        :param norm: which norm to use
+        :return: the "max_ent_minus" regularization term
+        """
+        return -1 * torch.norm(inf_scores, p=norm) + 0.1
+
+    @classmethod
+    def get_regularization_term(cls, inf_scores: torch.Tensor, norm: float = 2.0,
+                                optim_method: str = 'max_ent') -> torch.Tensor:
+        """
+        Computes the regularization term for a batch of information scores.
+        :param inf_scores: tensor holding a batch of information scores
+        :param norm: which norm to use (1 or 2)
+        :param optim_method: optimization method (possible methods: "min_ent", "max_ent", "max_ent_minus")
+        :return: the regularization term
+        """
+
+        if optim_method == 'max_ent':
+            return cls._get_max_ent(inf_scores, norm)
+        elif optim_method == 'min_ent':
+            return torch.norm(inf_scores, p=norm)
+        elif optim_method == 'max_ent_minus':
+            return cls._get_max_ent_minus(inf_scores, norm)
+
+        raise NotImplementedError(f'"{optim_method}" is an unknown optimization method, please use "min_ent", "max_ent" or "max_ent_minus".')
+
+
+class RegParameters(object):
+    """
+    This class holds all the regularization-related properties.
+    """
+    def __init__(self, lambda_: float = 1e-10, norm: float = 2.0, estimation: str = 'ent',
+                 optim_method: str = 'max_ent', n_samples: int = 10, grad: bool = True):
+        self.lambda_ = lambda_
+        self.norm = norm
+        self.estimation = estimation
+        self.optim_method = optim_method
+        self.n_samples = n_samples
+        self.grad = grad
\ No newline at end of file
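
Usage note (not part of the patch): a minimal sketch of how these utilities might be wired into a training step. The toy model, feature sizes, batch size, concatenation fusion, per-sample cross-entropy loss and the repeat_interleave target expansion below are illustrative assumptions, not the authors' training code; only the Perturbation, Regularization and RegParameters calls come from this module, and the import assumes the repository root is on PYTHONPATH.

import torch

from utils.mm_regularization import Perturbation, Regularization, RegParameters

reg = RegParameters()                    # lambda_, norm, estimation, optim_method, n_samples

model = torch.nn.Linear(16 + 8, 4)       # hypothetical stand-in for a multi-modal classifier
x = torch.randn(32, 16)                  # modality 1 features, batch of 32 (placeholder data)
y = torch.randn(32, 8)                   # modality 2 features (placeholder data)
targets = torch.randint(0, 4, (32,))

# 1) Duplicate and perturb each modality; perturb_tensor also enables gradients on its output.
x_pert = Perturbation.perturb_tensor(x, reg.n_samples)      # [32 * n_samples, 16]
y_pert = Perturbation.perturb_tensor(y, reg.n_samples)      # [32 * n_samples, 8]
targets_exp = targets.repeat_interleave(reg.n_samples)      # matches perturb_tensor's row layout

# 2) Per-sample loss on the perturbed batch (fusion by concatenation is an assumption).
logits = model(torch.cat([x_pert, y_pert], dim=1))
loss_per_sample = torch.nn.functional.cross_entropy(logits, targets_exp, reduction='none')

# 3) Gradient of the loss w.r.t. each perturbed modality -> one information score per modality.
grad_x, grad_y = torch.autograd.grad(loss_per_sample.sum(), [x_pert, y_pert], create_graph=True)
inf_scores = torch.stack([
    Regularization.get_batch_norm(grad_x.flatten(1), loss=loss_per_sample, estimation=reg.estimation),
    Regularization.get_batch_norm(grad_y.flatten(1), loss=loss_per_sample, estimation=reg.estimation),
])

# 4) Regularization term over the per-modality scores, added to the task loss.
reg_term = Regularization.get_regularization_term(inf_scores, norm=reg.norm, optim_method=reg.optim_method)
total_loss = loss_per_sample.mean() + reg.lambda_ * reg_term
total_loss.backward()

With the default optim_method='max_ent', the term penalizes modalities with a small information score, which is what pushes the classifier to draw on every modality rather than the dominant one.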