utils/mm_regularization.py
"""
An implementation of the paper: "Removing Bias in Multi-modal Classifiers: Regularization by Maximizing Functional
Entropies", NeurIPS 2020.
"""

import torch


class Perturbation:
    """
    Class in charge of the perturbation techniques.
    """
    @classmethod
    def _add_noise_to_tensor(cls, tens: torch.Tensor, over_dim: int = 0) -> torch.Tensor:
        """
        Adds noise sampled from N(0, tens.std()) to a tensor.
        :param tens: input tensor
        :param over_dim: over which dim to calculate the std: 0 for features over the batch, 1 for over the sample.
        :return: noisy tensor of the same shape as the input
        """

        return tens + torch.randn_like(tens) * tens.std(dim=over_dim)
        # return tens + torch.randn_like(tens)

    @classmethod
    def perturb_tensor(cls, tens: torch.Tensor, n_samples: int, perturbation: bool = True) -> torch.Tensor:
        """
        Flattens the tensor, expands it, perturbs it and reconstructs the original shape.
        Note, this function assumes that the batch is the first dimension.
        :param tens: input tensor
        :param n_samples: times to perturb
        :param perturbation: False - only duplicate the tensor, without adding noise
        :return: tensor whose first dimension is batch * n_samples, with the remaining dimensions unchanged
        """
        tens_dim = list(tens.shape)

        tens = tens.view(tens.shape[0], -1)
        tens = tens.repeat(1, n_samples)

        tens = tens.view(tens.shape[0] * n_samples, -1)

        if perturbation:
            tens = cls._add_noise_to_tensor(tens)

        tens_dim[0] *= n_samples

        tens = tens.view(*tens_dim)
        tens.requires_grad_()

        return tens

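    # Illustrative note (added, not in the original file): for an input of shape
    # [8, 16] with n_samples=4, perturb_tensor returns a [32, 16] tensor in which
    # rows 0-3 are noisy copies of sample 0, rows 4-7 of sample 1, and so on.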
    @classmethod
    def get_expanded_logits(cls, logits: torch.Tensor, n_samples: int, logits_flg: bool = True) -> torch.Tensor:
        """
        Applies softmax (if needed) and then expands the result n_samples times.
        :param logits_flg: whether the input is logits or softmax probabilities
        :param logits: tensor holding the logits output by the model
        :param n_samples: times to duplicate
        :return: tensor of shape [batch * n_samples, number of classes]
        """
        if logits_flg:
            logits = torch.nn.functional.softmax(logits, dim=1)
        expanded_logits = logits.repeat(1, n_samples)

        return expanded_logits.view(expanded_logits.shape[0] * n_samples, -1)


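# Note (added): perturb_tensor and get_expanded_logits use the same
# repeat-then-reshape layout, so row k of the expanded logits/targets lines up
# with row k of each perturbed modality tensor.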
class Regularization(object):
    """
    Class in charge of the regularization techniques.
    """
    @classmethod
    def _get_variance(cls, loss: torch.Tensor) -> torch.Tensor:
        """
        Computes the variance over the samples (second dimension) for each element of the batch.
        :param loss: [batch, number of evaluation samples]
        :return: variance for each element of the batch
        """

        return torch.var(loss, dim=1)

    @classmethod
    def _get_differential_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
        """
        Computes the differential entropy: -E[f log f]
        :param loss: batch of loss values
        :return: a tensor holding the differential entropy for a batch
        """

        return -1 * torch.sum(loss * loss.log())

    @classmethod
    def _get_functional_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
        """
        Computes the functional entropy: E[f log f] - E[f] log E[f]
        :param loss: batch of loss values, [batch, number of evaluation samples]
        :return: a tensor holding the functional entropy of the batch
        """
        # L1-normalize each row so that the loss values behave like a distribution over the samples.
        loss = torch.nn.functional.normalize(loss, p=1, dim=1)
        loss = torch.mean(loss * loss.log()) - (torch.mean(loss) * torch.mean(loss).log())

        return loss

    @classmethod
    def get_batch_statistics(cls, loss: torch.Tensor, n_samples: int, estimation: str = 'ent') -> torch.Tensor:
        """
        Calculate an information estimate (variance, functional entropy or differential entropy) over the batch.
        :param n_samples: number of perturbed samples per batch element
        :param loss: per-sample loss values over the expanded batch
        :param estimation: 'var', 'ent' or 'dif_ent'
        :return: expectation of the information score over the batch
        """
        # Each row now holds the n_samples perturbed copies of one original sample.
        loss = loss.reshape(-1, n_samples)

        if estimation == 'var':
            batch_statistics = cls._get_variance(loss)
            batch_statistics = torch.abs(batch_statistics)
        elif estimation == 'ent':
            batch_statistics = cls._get_functional_entropy(loss)
        elif estimation == 'dif_ent':
            batch_statistics = cls._get_differential_entropy(loss)
        else:
            raise NotImplementedError(f'{estimation} is an unknown estimation method, please use "var", "ent" or "dif_ent".')

        return torch.mean(batch_statistics)

    @classmethod
    def get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
        """
        Calculate the expectation of the squared gradient norm over the batch.
        :param loss: per-sample loss values, used to normalize the squared norm when estimation is 'ent'
        :param estimation: 'ent' - divide the squared gradient norm by the loss values
        :param grad: tensor holding the batch of gradients
        :return: approximation of the required expectation
        """
        batch_grad_norm = torch.norm(grad, p=2, dim=1)
        batch_grad_norm = torch.pow(batch_grad_norm, 2)

        if estimation == 'ent':
            batch_grad_norm = batch_grad_norm / loss

        return torch.mean(batch_grad_norm)

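    # Note (added, following the paper): dividing the squared gradient norm by the
    # loss corresponds to a log-Sobolev-style upper bound on the functional entropy,
    # roughly Ent(f) <= C * E[||grad f||^2 / f], which is what estimation='ent' uses.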
    @classmethod
    def _get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
        """
        Calculate the per-sample squared gradient norm, without averaging over the batch.
        :param loss: per-sample loss values, used to normalize the squared norm when estimation is 'ent'
        :param estimation: 'ent' - divide the squared gradient norm by the loss values
        :param grad: tensor holding the batch of gradients
        :return: per-sample approximation of the required expectation
        """
        batch_grad_norm = torch.norm(grad, p=2, dim=1)
        batch_grad_norm = torch.pow(batch_grad_norm, 2)

        if estimation == 'ent':
            batch_grad_norm = batch_grad_norm / loss

        return batch_grad_norm

    @classmethod
    def _get_max_ent(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
        """
        Calculate the norm of 1 divided by the information scores (maximizing entropy).
        :param inf_scores: tensor holding batch information scores
        :param norm: which norm to use
        :return: the regularization term
        """
        return torch.norm(torch.div(1, inf_scores), p=norm)

    @classmethod
    def _get_max_ent_minus(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
        """
        Calculate -1 times the norm of the information scores, plus a small constant offset.
        :param inf_scores: tensor holding batch information scores
        :param norm: which norm to use
        :return: the regularization term
        """
        return -1 * torch.norm(inf_scores, p=norm) + 0.1

    @classmethod
    def get_regularization_term(cls, inf_scores: torch.Tensor, norm: float = 2.0,
                                optim_method: str = 'max_ent') -> torch.Tensor:
        """
        Compute the regularization term given a batch of information scores.
        :param inf_scores: tensor holding a batch of information scores
        :param norm: defines which norm to use (1 or 2)
        :param optim_method: optimization method; supported values are "min_ent", "max_ent" and "max_ent_minus"
        :return: the regularization term
        """

        if optim_method == 'max_ent':
            return cls._get_max_ent(inf_scores, norm)
        elif optim_method == 'min_ent':
            return torch.norm(inf_scores, p=norm)
        elif optim_method == 'max_ent_minus':
            return cls._get_max_ent_minus(inf_scores, norm)

        raise NotImplementedError(f'"{optim_method}" is an unknown optimization method')


class RegParameters(object):
    """
    This class controls all the regularization-related properties.
    """
    def __init__(self, lambda_: float = 1e-10, norm: float = 2.0, estimation: str = 'ent',
                 optim_method: str = 'max_ent', n_samples: int = 10, grad: bool = True):
        self.lambda_ = lambda_
        self.norm = norm
        self.estimation = estimation
        self.optim_method = optim_method
        self.n_samples = n_samples
        self.grad = grad
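

# ---------------------------------------------------------------------------
# Added usage sketch: a minimal, self-contained illustration of how the
# utilities above could be combined in a training step. The "model" (a single
# linear head over two random modalities), the tensor shapes and the labels are
# hypothetical placeholders, not the training recipe from the original repo.
if __name__ == '__main__':
    torch.manual_seed(0)
    reg_params = RegParameters(n_samples=4)

    batch_size, dim_a, dim_b, n_classes = 8, 16, 12, 3
    modality_a = torch.randn(batch_size, dim_a)
    modality_b = torch.randn(batch_size, dim_b)
    head = torch.nn.Linear(dim_a + dim_b, n_classes)

    # Perturb each modality; every sample is duplicated n_samples times with noise.
    pert_a = Perturbation.perturb_tensor(modality_a, reg_params.n_samples)
    pert_b = Perturbation.perturb_tensor(modality_b, reg_params.n_samples)
    logits = head(torch.cat([pert_a, pert_b], dim=1))

    # Per-sample loss over the expanded batch; labels are repeated to match the
    # repeat-then-reshape layout used by perturb_tensor.
    labels = torch.randint(0, n_classes, (batch_size,)).repeat_interleave(reg_params.n_samples)
    loss = torch.nn.functional.cross_entropy(logits, labels, reduction='none')

    # One information score per modality: squared gradient norm of the loss
    # w.r.t. the perturbed input, normalized by the loss ('ent' estimation).
    inf_scores = []
    for pert in (pert_a, pert_b):
        grad = torch.autograd.grad(loss.sum(), pert, retain_graph=True)[0]
        grad = grad.view(grad.shape[0], -1)
        inf_scores.append(Regularization.get_batch_norm(grad, loss=loss,
                                                        estimation=reg_params.estimation))
    inf_scores = torch.stack(inf_scores)

    reg_term = Regularization.get_regularization_term(inf_scores, norm=reg_params.norm,
                                                      optim_method=reg_params.optim_method)
    total_loss = loss.mean() + reg_params.lambda_ * reg_term
    print(f'regularization term: {reg_term.item():.4f}, total loss: {total_loss.item():.4f}')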