# rocaseg/components/mixup.py
1
import torch
2
import numpy as np
3
4
"""
5
Example usage:
6
7
# Regular segmentation loss:
8
ys_pred_oai = self.models['segm'](xs_oai)
9
loss_segm = self.losses['segm'](input_=ys_pred_oai,
10
                                target=ys_true_arg_oai)
11
12
# Mixup
13
xs_mixup, ys_mixup_a, ys_mixup_b, lambda_mixup = mixup_data(
14
    x=xs_oai, y=ys_true_arg_oai,
15
    alpha=self.config['mixup_alpha'], device=maybe_gpu)
16
ys_pred_oai = self.models['segm'](xs_mixup)
17
loss_segm = mixup_criterion(criterion=self.losses['segm'],
18
                            pred=ys_pred_oai,
19
                            y_a=ys_mixup_a,
20
                            y_b=ys_mixup_b,
21
                            lam=lambda_mixup)
22
"""
23
24
25
def mixup_data(x, y, alpha=1.0, device='cpu'):
    """Mix a batch with a randomly permuted copy of itself (mixup).

    Args:
        x: Input batch; dimension 0 is treated as the batch dimension.
            Any rank >= 1 is accepted.
        y: Targets for ``x``; reordered along dimension 0 with the same
            random permutation that is applied to ``x``.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution
            the mixing weight is drawn from. If ``alpha <= 0`` the weight
            is fixed at 1.0, so ``mixed_x`` carries (effectively) no
            contribution from the permuted samples.
        device: Device the permutation index tensor is moved to.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[index]``, ``y_a`` is ``y``
        itself, ``y_b`` is ``y[index]``, and ``lam`` is the mixing weight
        as a Python float.
    """
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        # Float (not int) so the type of ``lam`` does not depend on alpha.
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(device)

    # Index along dim 0 only: ``x[index, :]`` raised for 1-D inputs, while
    # ``x[index]`` is identical for rank >= 2 and also works for rank 1.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
38
39
40
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss of a single prediction against two target sets.

    ``criterion`` is called first as ``criterion(pred, y_a)`` and then as
    ``criterion(pred, y_b)``. The two results are combined as
    ``lam * loss_a + (1 - lam) * loss_b``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b