Diff of /sybil/utils/losses.py [000000] .. [d9566e]

from collections import OrderedDict

import torch
import torch.nn.functional as F
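
# Shared contract: every loss below takes (model_output, batch, model, args)
# and returns (loss, logging_dict, predictions), where logging_dict holds
# detached scalars for logging and predictions holds detached tensors for
# downstream metric computation.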


def get_cross_entropy_loss(model_output, batch, model, args):
    logging_dict, predictions = OrderedDict(), OrderedDict()
    logit = model_output["logit"]
    loss = F.cross_entropy(logit, batch["y"].long())
    logging_dict["cross_entropy_loss"] = loss.detach()
    predictions["probs"] = F.softmax(logit, dim=-1).detach()
    predictions["golds"] = batch["y"]
    return loss, logging_dict, predictions
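

# Hedged usage sketch (illustrative, not part of the training loop): a
# minimal call with dummy tensors for a binary task and batch size 4. The
# function ignores `model` and `args`, so placeholders suffice.
def _example_cross_entropy_usage():
    model_output = {"logit": torch.randn(4, 2)}  # (B, num_classes)
    batch = {"y": torch.tensor([0, 1, 1, 0])}
    loss, logs, preds = get_cross_entropy_loss(model_output, batch, None, None)
    assert preds["probs"].shape == (4, 2)
    return loss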


def get_survival_loss(model_output, batch, model, args):
    logging_dict, predictions = OrderedDict(), OrderedDict()
    logit = model_output["logit"]
    y_seq, y_mask = batch["y_seq"], batch["y_mask"]
    loss = F.binary_cross_entropy_with_logits(
        logit, y_seq.float(), weight=y_mask.float(), reduction="sum"
    ) / torch.sum(y_mask.float())
    logging_dict["survival_loss"] = loss.detach()
    predictions["probs"] = torch.sigmoid(logit).detach()
    predictions["golds"] = batch["y"]
    predictions["censors"] = batch["time_at_event"]
    return loss, logging_dict, predictions
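

# Hedged sketch of the survival targets this loss expects, inferred from the
# tensors used above (an assumption, not original documentation). With
# max_followup = 5, a patient diagnosed at year 2 gets a y_seq that is
# positive from the event year onward and a fully observed y_mask, while a
# patient censored after year 1 only contributes loss on the observed years.
def _example_survival_usage():
    batch = {
        "y": torch.tensor([1, 0]),
        "y_seq": torch.tensor([[0, 0, 1, 1, 1], [0, 0, 0, 0, 0]]),
        "y_mask": torch.tensor([[1, 1, 1, 1, 1], [1, 1, 0, 0, 0]]),
        "time_at_event": torch.tensor([2, 1]),
    }
    model_output = {"logit": torch.zeros(2, 5)}  # (B, max_followup)
    loss, logs, preds = get_survival_loss(model_output, batch, None, None)
    return loss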


def get_annotation_loss(model_output, batch, model, args):
    total_loss, logging_dict, predictions = 0, OrderedDict(), OrderedDict()

    B, _, N, H, W = model_output["activ"].shape

    batch_mask = batch["has_annotation"]

    for attn_num in [1, 2]:

        side_attn = -1
        if model_output.get("image_attention_{}".format(attn_num), None) is not None:
            if len(batch["image_annotations"].shape) == 4:
                batch["image_annotations"] = batch["image_annotations"].unsqueeze(1)

            # resize annotation to 'activ' size
            annotation_gold = F.interpolate(
                batch["image_annotations"], (N, H, W), mode="area"
            )
            annotation_gold = annotation_gold * batch_mask[:, None, None, None, None]

            # renormalize scores so each annotated slice sums to 1
            mask_area = annotation_gold.sum(dim=(-1, -2)).unsqueeze(-1).unsqueeze(-1)
            mask_area[mask_area == 0] = 1
            annotation_gold /= mask_area

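            # Worked example of the renormalization above (hypothetical
            # numbers): a resized slice mask with non-zero values [0.2, 0.6]
            # has mask_area = 0.8 and becomes the target distribution
            # [0.25, 0.75]; slices with no annotation keep mask_area = 1, so
            # they stay all-zero instead of dividing by zero.
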
            # reshape annotation into 1D vector
            annotation_gold = annotation_gold.view(B, N, -1).float()

            # mask over annotation boxes so that non-annotated scores
            # get zero weight when computing the loss
            annotation_gold_mask = (annotation_gold > 0).float()

            num_annotated_samples = (annotation_gold.view(B * N, -1).sum(-1) > 0).sum()
            num_annotated_samples = max(1, num_annotated_samples)

            pred_attn = (
                model_output["image_attention_{}".format(attn_num)]
                * batch_mask[:, None, None]
            )
            # F.kl_div expects its input in log-space and its target as
            # probabilities, so pred_attn must already be log-softmaxed
            kldiv = (
                F.kl_div(pred_attn, annotation_gold, reduction="none")
                * annotation_gold_mask
            )

            # sum loss per volume and average over annotated samples
            loss = kldiv.sum() / num_annotated_samples
            logging_dict["image_attention_loss_{}".format(attn_num)] = loss.detach()
            total_loss += args.image_attention_loss_lambda * loss
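
            # Pointwise, kl_div with reduction="none" computes
            # target * (log(target) - input); multiplying by
            # annotation_gold_mask zeroes the contribution of pixels outside
            # the annotation before the sum above.
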
            # attend to cancer side
            cancer_side_mask = (batch["cancer_laterality"][:, :2].sum(-1) == 1).float()[
                :, None
            ]  # only one side is positive
            cancer_side_gold = (
                batch["cancer_laterality"][:, 1].unsqueeze(1).repeat(1, N)
            )  # left side (seen as lung on right) is positive class
            num_annotated_samples = max(N * cancer_side_mask.sum(), 1)
            # sum attention mass over the left and right halves of the width axis
            side_attn = torch.exp(model_output["image_attention_{}".format(attn_num)])
            side_attn = side_attn.view(B, N, H, W)
            side_attn = torch.stack(
                [
                    side_attn[:, :, :, : W // 2].sum((2, 3)),
                    side_attn[:, :, :, W // 2 :].sum((2, 3)),
                ],
                dim=-1,
            )
            side_attn_log = F.log_softmax(side_attn, dim=-1).transpose(1, 2)

            # cross_entropy applies log_softmax internally; because
            # log_softmax is shift-invariant, feeding it log-probabilities
            # is redundant but does not change the result
            loss = (
                F.cross_entropy(side_attn_log, cancer_side_gold, reduction="none")
                * cancer_side_mask
            ).sum() / num_annotated_samples
            logging_dict[
                "image_side_attention_loss_{}".format(attn_num)
            ] = loss.detach()
            total_loss += args.image_attention_loss_lambda * loss

        if model_output.get("volume_attention_{}".format(attn_num), None) is not None:
107
            # find size of annotation box per slice and normalize
108
            annotation_gold = batch["annotation_areas"].float() * batch_mask[:, None]
109
110
            if N != args.num_images:
111
                annotation_gold = F.interpolate(annotation_gold.unsqueeze(1), (N), mode= 'linear', align_corners = True)[:,0]
112
            area_per_slice = annotation_gold.sum(-1).unsqueeze(-1)
113
            area_per_slice[area_per_slice == 0] = 1
114
            annotation_gold /= area_per_slice
115
116
            num_annotated_samples = (annotation_gold.sum(-1) > 0).sum()
117
            num_annotated_samples = max(1, num_annotated_samples)
118
119
            # find slices with annotation
120
            annotation_gold_mask = (annotation_gold > 0).float()
121
122
            pred_attn = (
123
                model_output["volume_attention_{}".format(attn_num)]
124
                * batch_mask[:, None]
125
            )
126
            kldiv = (
127
                F.kl_div(pred_attn, annotation_gold, reduction="none")
128
                * annotation_gold_mask
129
            )  # B, N
130
            loss = kldiv.sum() / num_annotated_samples
131
132
            logging_dict["volume_attention_loss_{}".format(attn_num)] = loss.detach()
133
            total_loss += args.volume_attention_loss_lambda * loss
134
            
135
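            # side_attn is only a tensor if the image-attention branch above
            # ran for this attn_num; in that case the per-slice left/right
            # attention mass is reused here, weighted by the volume attention
            # over slices.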
            if isinstance(side_attn, torch.Tensor):
                # attend to cancer side
                cancer_side_mask = (
                    batch["cancer_laterality"][:, :2].sum(-1) == 1
                ).float()  # only one side is positive
                cancer_side_gold = batch["cancer_laterality"][
                    :, 1
                ]  # left side (seen as lung on right) is positive class
                num_annotated_samples = max(cancer_side_mask.sum(), 1)

                pred_attn = torch.exp(
                    model_output["volume_attention_{}".format(attn_num)]
                )
                side_attn = (side_attn * pred_attn.unsqueeze(-1)).sum(1)
                side_attn_log = F.log_softmax(side_attn, dim=-1)

                loss = (
                    F.cross_entropy(side_attn_log, cancer_side_gold, reduction="none")
                    * cancer_side_mask
                ).sum() / num_annotated_samples
                logging_dict[
                    "volume_side_attention_loss_{}".format(attn_num)
                ] = loss.detach()
                total_loss += args.volume_attention_loss_lambda * loss

    return total_loss * args.annotation_loss_lambda, logging_dict, predictions
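

# Shape summary for get_annotation_loss, inferred from the indexing above
# (treat these as assumptions rather than documented shapes):
#   model_output["activ"]              : (B, C, N, H, W) feature maps
#   model_output["image_attention_k"]  : (B, N, H*W) log-probabilities over pixels
#   model_output["volume_attention_k"] : (B, N) log-probabilities over slices
#   batch["image_annotations"]         : (B, 1, N_in, H_in, W_in) box masks
#   batch["annotation_areas"]          : (B, args.num_images) per-slice box areas
#   batch["cancer_laterality"]         : (B, 3) indicator [right, left, other]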


def get_risk_factor_loss(model_output, batch, model, args):
    total_loss, logging_dict, predictions = 0, OrderedDict(), OrderedDict()

    for idx, key in enumerate(args.risk_factor_keys):
        logit = model_output["{}_logit".format(key)]
        gold_rf = batch["risk_factors"][idx]
        # a risk factor is "known" when its one-hot row is non-zero; keep the
        # mask 1-D so it broadcasts element-wise against the per-sample loss
        is_rf_known = (torch.sum(gold_rf, dim=-1) > 0).float()

        gold = torch.argmax(gold_rf, dim=-1).contiguous().view(-1)

        loss = (
            F.cross_entropy(logit, gold, reduction="none") * is_rf_known
        ).sum() / max(1, is_rf_known.sum())
        total_loss += loss
        logging_dict["{}_loss".format(key)] = loss.detach()

        probs = F.softmax(logit, dim=-1).detach()
        predictions["{}_probs".format(key)] = probs
        predictions["{}_golds".format(key)] = gold.detach()
        predictions["{}_risk_factor".format(key)] = batch["risk_factors"][idx]

    return total_loss * args.primary_loss_lambda, logging_dict, predictions
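

# Hedged sketch of the risk-factor encoding this loss assumes (an inference
# from the masking above, not original documentation): each factor is a
# one-hot row per sample, and an all-zero row means "unknown", so unknown
# samples are excluded from both the numerator and the denominator.
def _example_risk_factor_usage():
    import argparse

    args = argparse.Namespace(risk_factor_keys=["smoking"], primary_loss_lambda=1.0)
    batch = {
        "risk_factors": [
            torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])  # third row unknown
        ]
    }
    model_output = {"smoking_logit": torch.randn(3, 2)}
    loss, logs, preds = get_risk_factor_loss(model_output, batch, None, args)
    return loss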


def discriminator_loss(model_output, batch, model, args):
    logging_dict, predictions = OrderedDict(), OrderedDict()
    d_output = model.discriminator(model_output, batch)
    loss = (
        F.cross_entropy(d_output["logit"], batch["origin_dataset"].long())
        * args.adv_loss_lambda
    )
    logging_dict["discrim_loss"] = loss.detach()
    # softmax so the stored values are probabilities, matching the key name
    predictions["discrim_probs"] = F.softmax(d_output["logit"], dim=-1).detach()
    predictions["discrim_golds"] = batch["origin_dataset"]

    # adversarial step: negating the loss makes the model maximize the
    # discriminator's error instead of minimizing it
    if model.reverse_discrim_loss:
        loss = -loss

    return loss, logging_dict, predictions
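

# Hedged sketch of the adversarial wiring this loss assumes; the stub
# discriminator and the Namespace fields below are illustrative stand-ins,
# not the real sybil model components.
def _example_discriminator_usage():
    import argparse
    import types

    class _StubDiscriminator:
        def __call__(self, model_output, batch):
            return {"logit": torch.randn(4, 2)}  # (B, num_source_datasets)

    model = types.SimpleNamespace(
        discriminator=_StubDiscriminator(), reverse_discrim_loss=False
    )
    args = argparse.Namespace(adv_loss_lambda=1.0)
    batch = {"origin_dataset": torch.tensor([0, 1, 0, 1])}
    loss, logs, preds = discriminator_loss({}, batch, model, args)
    return loss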