utils/mm_regularization.py

"""
An implementation of the paper: "Removing Bias in Multi-modal Classifiers: Regularization by Maximizing
Functional Entropies", NeurIPS 2020.
"""
|
|

import torch


class Perturbation:
|
|
    """
    Class in charge of the perturbation techniques.
    """
|
|
    @classmethod
    def _add_noise_to_tensor(cls, tens: torch.Tensor, over_dim: int = 0) -> torch.Tensor:
        """
        Adds noise sampled from N(0, tens.std(dim=over_dim)) to a tensor.
        :param tens: input tensor
        :param over_dim: over which dim to calculate the std: 0 for features over the batch, 1 for over the sample.
        :return: noisy tensor of the same shape as the input
        """

        return tens + torch.randn_like(tens) * tens.std(dim=over_dim)
        # Unscaled alternative: return tens + torch.randn_like(tens)
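
    # Note: with over_dim=0 and a 2-D input, tens.std(dim=0) is a per-feature
    # vector that broadcasts across the batch, so each feature receives noise
    # scaled to its own spread.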
|
|
    @classmethod
    def perturb_tensor(cls, tens: torch.Tensor, n_samples: int, perturbation: bool = True) -> torch.Tensor:
        """
        Flattens the tensor, expands it, perturbs it and reconstructs the original shape.
        Note: this function assumes that the batch is the first dimension.
        :param tens: input tensor
        :param n_samples: number of times to perturb
        :param perturbation: False - only duplicate the tensor, without adding noise
        :return: tensor of shape [batch * n_samples, *rest], where rest is the input's trailing dimensions
        """
        tens_dim = list(tens.shape)

        tens = tens.view(tens.shape[0], -1)
        tens = tens.repeat(1, n_samples)

        tens = tens.view(tens.shape[0] * n_samples, -1)

        if perturbation:
            tens = cls._add_noise_to_tensor(tens)

        tens_dim[0] *= n_samples

        tens = tens.view(*tens_dim)
        tens.requires_grad_()

        return tens
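
    # Illustrative usage (the shapes are assumptions, not from the original source):
    #   feats = torch.randn(8, 16)                               # [batch, features]
    #   noisy = Perturbation.perturb_tensor(feats, n_samples=10)
    #   assert noisy.shape == (80, 16) and noisy.requires_grad
    # Each group of 10 consecutive rows holds noisy copies of one input row.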
|
|
    @classmethod
    def get_expanded_logits(cls, logits: torch.Tensor, n_samples: int, logits_flg: bool = True) -> torch.Tensor:
        """
        Applies softmax (when given raw logits) and then expands the result n_samples times.
        :param logits_flg: whether the input holds logits or softmax probabilities
        :param logits: tensor holding the model's outputs
        :param n_samples: number of times to duplicate
        :return: tensor of shape [batch * n_samples, num_classes]
        """
        if logits_flg:
            logits = torch.nn.functional.softmax(logits, dim=1)
        expanded_logits = logits.repeat(1, n_samples)

        return expanded_logits.view(expanded_logits.shape[0] * n_samples, -1)
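
    # Illustrative note: with logits of shape [8, 4] (an assumed shape) and
    # n_samples=10, the result is [80, 4], row-aligned with the output of
    # perturb_tensor, so a per-perturbation loss can be computed directly.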
|
|
class Regularization(object):
    """
    Class in charge of the regularization techniques.
    """
    @classmethod
    def _get_variance(cls, loss: torch.Tensor) -> torch.Tensor:
        """
        Computes, for each item in the batch, the variance across its perturbation samples.
        :param loss: [batch, number of evaluation samples]
        :return: per-item variance of a given batch of loss values
        """

        return torch.var(loss, dim=1)
    @classmethod
    def _get_differential_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
        """
        Computes a differential-entropy style estimate: -sum(f * log(f)).
        :param loss: [batch, number of evaluation samples]
        :return: a tensor holding the differential entropy for a batch
        """

        return -1 * torch.sum(loss * loss.log())
    @classmethod
    def _get_functional_entropy(cls, loss: torch.Tensor) -> torch.Tensor:
        """
        Computes the functional entropy: E[f * log(f)] - E[f] * log(E[f]),
        after L1-normalizing the loss values of each batch item.
        :param loss: [batch, number of evaluation samples]
        :return: a tensor holding the functional entropy for a batch
        """
        loss = torch.nn.functional.normalize(loss, p=1, dim=1)
        loss = torch.mean(loss * loss.log()) - (torch.mean(loss) * torch.mean(loss).log())

        return loss
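
    # Sanity check (illustrative): for a constant f, E[f * log(f)] equals
    # E[f] * log(E[f]), so the functional entropy is zero; e.g.
    #   Regularization._get_functional_entropy(torch.full((2, 4), 0.5))
    # returns a tensor numerically equal to 0.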
|
|
    @classmethod
    def get_batch_statistics(cls, loss: torch.Tensor, n_samples: int, estimation: str = 'ent') -> torch.Tensor:
        """
        Computes the chosen information estimate over a batch of loss values.
        :param loss: flat tensor of loss values, reshaped here to [batch, n_samples]
        :param n_samples: number of perturbation samples per batch item
        :param estimation: 'var', 'ent' or 'dif_ent'
        :return: mean information estimate of the batch
        """
        loss = loss.reshape(-1, n_samples)

        if estimation == 'var':
            batch_statistics = cls._get_variance(loss)
            batch_statistics = torch.abs(batch_statistics)
        elif estimation == 'ent':
            batch_statistics = cls._get_functional_entropy(loss)
        elif estimation == 'dif_ent':
            batch_statistics = cls._get_differential_entropy(loss)
        else:
            raise NotImplementedError(f'{estimation} is an unknown estimation, please use "var", "ent" or "dif_ent".')

        return torch.mean(batch_statistics)
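
    # Illustrative call (names assumed): per_sample_loss holds one loss value
    # per perturbed copy, i.e. shape [batch * n_samples]:
    #   inf_score = Regularization.get_batch_statistics(per_sample_loss, n_samples=10, estimation='ent')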
|
|
    @classmethod
    def get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
        """
        Approximates the expectation of the squared gradient norm over a batch.
        For the 'ent' estimation, the squared norm is divided by the loss values
        (a functional Fisher information style quantity).
        :param loss: tensor of loss values, used only when estimation == 'ent'
        :param estimation: 'ent' or any other value for the plain squared norm
        :param grad: tensor holding the batch of gradients
        :return: approximation of the required expectation (a scalar mean)
        """
        batch_grad_norm = torch.norm(grad, p=2, dim=1)
        batch_grad_norm = torch.pow(batch_grad_norm, 2)

        if estimation == 'ent':
            batch_grad_norm = batch_grad_norm / loss

        return torch.mean(batch_grad_norm)
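
    # Background note (from the paper, stated loosely): the log-Sobolev
    # inequality relates the functional entropy to the functional Fisher
    # information E[||grad f||^2 / f], which is why the 'ent' estimation
    # divides the squared gradient norm by the loss values.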
|
|
    @classmethod
    def _get_batch_norm(cls, grad: torch.Tensor, loss: torch.Tensor = None, estimation: str = 'ent') -> torch.Tensor:
        """
        Same as get_batch_norm, but returns the per-sample values instead of their mean.
        :param loss: tensor of loss values, used only when estimation == 'ent'
        :param estimation: 'ent' or any other value for the plain squared norm
        :param grad: tensor holding the batch of gradients
        :return: per-sample approximation of the required expectation
        """
        batch_grad_norm = torch.norm(grad, p=2, dim=1)
        batch_grad_norm = torch.pow(batch_grad_norm, 2)

        if estimation == 'ent':
            batch_grad_norm = batch_grad_norm / loss

        return batch_grad_norm
|
|
    @classmethod
    def _get_max_ent(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
        """
        Calculates the norm of the element-wise reciprocal of the information scores.
        :param inf_scores: tensor holding batch information scores
        :param norm: which norm to use
        :return: scalar regularization value
        """
        return torch.norm(torch.div(1, inf_scores), p=norm)
|
|
    @classmethod
    def _get_max_ent_minus(cls, inf_scores: torch.Tensor, norm: float) -> torch.Tensor:
        """
        Calculates -1 times the norm of the information scores, plus a 0.1 offset.
        :param inf_scores: tensor holding batch information scores
        :param norm: which norm to use
        :return: scalar regularization value
        """
        return -1 * torch.norm(inf_scores, p=norm) + 0.1
|
|
    @classmethod
    def get_regularization_term(cls, inf_scores: torch.Tensor, norm: float = 2.0,
                                optim_method: str = 'max_ent') -> torch.Tensor:
        """
        Computes the regularization term given a batch of information scores.
        :param inf_scores: tensor holding a batch of information scores
        :param norm: defines which norm to use (1 or 2)
        :param optim_method: the optimization method ('min_ent', 'max_ent' or 'max_ent_minus')
        :return: scalar regularization term
        """

        if optim_method == 'max_ent':
            return cls._get_max_ent(inf_scores, norm)
        elif optim_method == 'min_ent':
            return torch.norm(inf_scores, p=norm)
        elif optim_method == 'max_ent_minus':
            return cls._get_max_ent_minus(inf_scores, norm)

        raise NotImplementedError(f'"{optim_method}" is an unknown optimization method')
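
    # Intuition: for 'max_ent', minimizing the norm of 1 / inf_scores pushes
    # every modality's information score upward, encouraging the classifier to
    # draw on all modalities rather than the single easiest one.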
|
|
class RegParameters(object):
    """
    This class controls all the regularization-related properties.
    """
    def __init__(self, lambda_: float = 1e-10, norm: float = 2.0, estimation: str = 'ent',
                 optim_method: str = 'max_ent', n_samples: int = 10, grad: bool = True):
        self.lambda_ = lambda_
        self.norm = norm
        self.estimation = estimation
        self.optim_method = optim_method
        self.n_samples = n_samples
        self.grad = grad
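

if __name__ == '__main__':
    # Minimal end-to-end sketch of how the pieces fit together. Everything
    # below (the linear model, the shapes, the KL-based per-copy loss) is an
    # illustrative assumption, not part of the original code.
    reg_params = RegParameters()

    model = torch.nn.Linear(16, 4)   # stand-in for a single modality branch
    features = torch.randn(8, 16)    # [batch, features]

    # Perturb the inputs and expand the clean predictions to match, row for row.
    perturbed = Perturbation.perturb_tensor(features, reg_params.n_samples)
    expanded = Perturbation.get_expanded_logits(model(features), reg_params.n_samples)

    # One loss value per perturbed copy; the epsilon keeps log/div finite.
    log_preds = torch.nn.functional.log_softmax(model(perturbed), dim=1)
    loss = torch.nn.functional.kl_div(log_preds, expanded, reduction='none').sum(dim=1) + 1e-8

    # Gradient-based information score for this (single) modality.
    grads = torch.autograd.grad(loss.sum(), perturbed, create_graph=True)[0]
    inf_score = Regularization.get_batch_norm(grads.flatten(1), loss, reg_params.estimation)

    # With several modalities there would be one score per modality; here just one.
    inf_scores = inf_score.unsqueeze(0)
    reg_term = Regularization.get_regularization_term(inf_scores, reg_params.norm,
                                                      reg_params.optim_method)
    # In training, reg_params.lambda_ * reg_term would be added to the task loss.
    print(f'regularization term: {reg_term.item():.6f}')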