|
a |
|
b/rocaseg/components/mixup.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
|
|
|
4 |
"""
Helpers for mixup augmentation: blend pairs of samples within a batch and
combine the losses computed against both sets of targets.

Example usage:

    # Regular segmentation loss:
    ys_pred_oai = self.models['segm'](xs_oai)
    loss_segm = self.losses['segm'](input_=ys_pred_oai,
                                    target=ys_true_arg_oai)

    # Mixup
    xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
        x=xs_oai, y=ys_true_arg_oai,
        alpha=self.config['mixup_alpha'], device=maybe_gpu)
    ys_pred_oai = self.models['segm'](xs_mixup)
    loss_segm = mixup_criterion(criterion=self.losses['segm'],
                                pred=ys_pred_oai,
                                y_a=ys_mixup_a,
                                y_b=ys_mixup_b,
                                lam=lambda_mixup)
"""
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def mixup_data(x, y, alpha=1.0, device='cpu'):
    """Returns mixed inputs, pairs of targets, and lambda.

    Each sample of the batch ``x`` is blended with a partner sample from the
    same batch. The partner is chosen by a random permutation of the first
    (batch) dimension::

        mixed_x = lam * x + (1 - lam) * x[index]

    Args:
        x: Batch of inputs. The first dimension is treated as the batch
            dimension, and any number of trailing dimensions is accepted.
        y: Batch of targets, indexable along its first dimension with the
            same permutation as ``x``.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution
            that ``lam`` is drawn from. When ``alpha <= 0``, mixing is
            disabled and ``lam`` is 1.0.
        device: Device the permutation index is moved to.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)``:

        - ``y_a`` is ``y`` unchanged.
        - ``y_b`` is ``y`` reordered with the same permutation used to pick
          the partner inputs.
        - ``lam`` is a Python float in [0, 1].
    """
    # Always return a Python float so ``lam * tensor`` keeps the tensor's
    # dtype and both branches produce the same type.
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(device)

    # ``x[index]`` permutes only the batch dimension, so inputs of any rank
    # >= 1 work. ``x[index, :]`` would fail on 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
|
|
38 |
|
|
|
39 |
|
|
|
40 |
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the losses of ``pred`` against two sets of targets.

    ``criterion(pred, y_a)`` is evaluated first, then ``criterion(pred, y_b)``.
    The two losses are weighted by ``lam`` and ``1 - lam`` respectively and
    summed. It is meant to be paired with the targets and ``lam`` returned by
    ``mixup_data``.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Predictions passed to both criterion calls.
        y_a: First set of targets, weighted by ``lam``.
        y_b: Second set of targets, weighted by ``1 - lam``.
        lam: Mixing coefficient.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    weight_b = 1 - lam
    return lam * loss_a + weight_b * loss_b