|
a |
|
b/sybil/utils/losses.py |
|
|
1 |
from collections import OrderedDict |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
import torch.nn.functional as F |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def get_cross_entropy_loss(model_output, batch, model, args):
    """Standard multi-class cross-entropy on ``model_output['logit']``.

    Returns a ``(loss, logging_dict, predictions)`` triple: the scalar loss,
    a dict of detached scalars for logging, and a dict holding softmax
    probabilities plus the gold labels from ``batch['y']``.
    """
    logging_dict = OrderedDict()
    predictions = OrderedDict()

    logits = model_output["logit"]
    targets = batch["y"].long()

    loss = F.cross_entropy(logits, targets)

    logging_dict["cross_entropy_loss"] = loss.detach()
    predictions["probs"] = F.softmax(logits, dim=-1).detach()
    predictions["golds"] = batch["y"]

    return loss, logging_dict, predictions
|
|
15 |
|
|
|
16 |
|
|
|
17 |
def get_survival_loss(model_output, batch, model, args):
    """Masked per-interval survival loss.

    Each follow-up interval in ``batch['y_seq']`` is scored as an independent
    binary prediction; ``batch['y_mask']`` weights invalid intervals to zero
    and the summed BCE is normalized by the number of valid entries.
    Returns ``(loss, logging_dict, predictions)``; predictions carry sigmoid
    probabilities, gold labels, and censoring times.
    """
    logging_dict = OrderedDict()
    predictions = OrderedDict()

    logits = model_output["logit"]
    targets = batch["y_seq"].float()
    valid_mask = batch["y_mask"].float()

    # Sum BCE over valid entries only, then average by how many were valid.
    summed_bce = F.binary_cross_entropy_with_logits(
        logits, targets, weight=valid_mask, reduction="sum"
    )
    loss = summed_bce / torch.sum(valid_mask)

    logging_dict["survival_loss"] = loss.detach()
    predictions["probs"] = torch.sigmoid(logits).detach()
    predictions["golds"] = batch["y"]
    predictions["censors"] = batch["time_at_event"]

    return loss, logging_dict, predictions
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def get_annotation_loss(model_output, batch, model, args):
    """Attention-supervision losses from radiologist annotations.

    For each of two attention heads (``attn_num`` in 1, 2) this adds, when the
    corresponding keys are present in ``model_output``:

    * an image-level KL term pulling ``image_attention_{n}`` toward the
      (area-normalized, downsampled) annotation masks;
    * an image-level "cancer side" cross-entropy term on left/right pooled
      attention mass;
    * a volume-level KL term pulling ``volume_attention_{n}`` toward the
      per-slice annotation areas;
    * a volume-level cancer-side term combining the slice-pooled side
      attention with the volume attention.

    Each term is masked by ``batch['has_annotation']`` (or the laterality
    mask) and averaged over annotated samples only. Returns
    ``(total_loss * args.annotation_loss_lambda, logging_dict, predictions)``;
    ``predictions`` is always empty here.

    NOTE(review): assumes ``image_attention_{n}`` / ``volume_attention_{n}``
    are *log*-probabilities (the input convention of ``F.kl_div``) — confirm
    upstream model code.
    """
    total_loss, logging_dict, predictions = 0, OrderedDict(), OrderedDict()

    # B=batch, N=depth (slices), H/W=spatial dims of the activation map.
    B, _, N, H, W, = model_output["activ"].shape

    # 1.0 where a sample has a human annotation, used to zero unannotated rows.
    batch_mask = batch["has_annotation"]

    for attn_num in [1, 2]:

        # Sentinel: stays -1 unless the image branch below builds the
        # left/right pooled attention tensor; the volume branch checks this.
        side_attn = -1
        if model_output.get("image_attention_{}".format(attn_num), None) is not None:
            # NOTE(review): mutates batch["image_annotations"] in place
            # (adds a channel dim); harmless on the second attn_num pass
            # since the shape is then already 5-D.
            if len(batch["image_annotations"].shape) == 4:
                batch["image_annotations"] = batch["image_annotations"].unsqueeze(1)

            # resize annotation to 'activ' size
            annotation_gold = F.interpolate(
                batch["image_annotations"], (N, H, W), mode="area"
            )
            # Zero out samples without annotations.
            annotation_gold = annotation_gold * batch_mask[:, None, None, None, None]

            # renormalize scores
            # Per-slice mass; empty slices get area 1 to avoid division by 0.
            mask_area = annotation_gold.sum(dim=(-1, -2)).unsqueeze(-1).unsqueeze(-1)
            mask_area[mask_area == 0] = 1
            annotation_gold /= mask_area

            # reshape annotation into 1D vector
            annotation_gold = annotation_gold.view(B, N, -1).float()

            # get mask over annotation boxes in order to weigh
            # non-annotated scores with zero when computing loss
            annotation_gold_mask = (annotation_gold > 0).float()

            # Count slices that actually carry annotation mass (min 1).
            num_annotated_samples = (annotation_gold.view(B * N, -1).sum(-1) > 0).sum()
            num_annotated_samples = max(1, num_annotated_samples)

            pred_attn = (
                model_output["image_attention_{}".format(attn_num)]
                * batch_mask[:, None, None]
            )
            # Elementwise KL, then zero where the gold mask is empty.
            kldiv = (
                F.kl_div(pred_attn, annotation_gold, reduction="none")
                * annotation_gold_mask
            )

            # sum loss per volume and average over batches
            loss = kldiv.sum() / num_annotated_samples
            logging_dict["image_attention_loss_{}".format(attn_num)] = loss.detach()
            total_loss += args.image_attention_loss_lambda * loss

            # attend to cancer side
            cancer_side_mask = (batch["cancer_laterality"][:, :2].sum(-1) == 1).float()[
                :, None
            ]  # only one side is positive
            cancer_side_gold = (
                batch["cancer_laterality"][:, 1].unsqueeze(1).repeat(1, N)
            )  # left side (seen as lung on right) is positive class
            num_annotated_samples = max(N * cancer_side_mask.sum(), 1)
            # Back to probability space, then pool attention mass per side.
            side_attn = torch.exp(model_output["image_attention_{}".format(attn_num)])
            side_attn = side_attn.view(B, N, H, W)
            side_attn = torch.stack(
                [
                    side_attn[:, :, :, : W // 2].sum((2, 3)),
                    side_attn[:, :, :, W // 2 :].sum((2, 3)),
                ],
                dim=-1,
            )
            # (B, 2, N) class-first layout expected by F.cross_entropy.
            side_attn_log = F.log_softmax(side_attn, dim=-1).transpose(1, 2)

            # NOTE(review): F.cross_entropy applies its own log_softmax on top
            # of side_attn_log (already log-probabilities) — effectively a
            # double softmax; verify this is intentional before changing.
            # Targets are assumed integer (long) class indices — confirm.
            loss = (
                F.cross_entropy(side_attn_log, cancer_side_gold, reduction="none")
                * cancer_side_mask
            ).sum() / num_annotated_samples
            logging_dict[
                "image_side_attention_loss_{}".format(attn_num)
            ] = loss.detach()
            total_loss += args.image_attention_loss_lambda * loss

        if model_output.get("volume_attention_{}".format(attn_num), None) is not None:
            # find size of annotation box per slice and normalize
            annotation_gold = batch["annotation_areas"].float() * batch_mask[:, None]

            # Resample slice-area curve when the activ depth differs from the
            # configured slice count.
            if N != args.num_images:
                annotation_gold = F.interpolate(annotation_gold.unsqueeze(1), (N), mode= 'linear', align_corners = True)[:,0]
            # Normalize to a distribution over slices; empty rows avoid /0.
            area_per_slice = annotation_gold.sum(-1).unsqueeze(-1)
            area_per_slice[area_per_slice == 0] = 1
            annotation_gold /= area_per_slice

            num_annotated_samples = (annotation_gold.sum(-1) > 0).sum()
            num_annotated_samples = max(1, num_annotated_samples)

            # find slices with annotation
            annotation_gold_mask = (annotation_gold > 0).float()

            pred_attn = (
                model_output["volume_attention_{}".format(attn_num)]
                * batch_mask[:, None]
            )
            kldiv = (
                F.kl_div(pred_attn, annotation_gold, reduction="none")
                * annotation_gold_mask
            )  # B, N
            loss = kldiv.sum() / num_annotated_samples

            logging_dict["volume_attention_loss_{}".format(attn_num)] = loss.detach()
            total_loss += args.volume_attention_loss_lambda * loss

            # Only runs when the image branch above produced the pooled
            # left/right attention tensor for this attn_num.
            if isinstance(side_attn, torch.Tensor):
                # attend to cancer side
                cancer_side_mask = (
                    batch["cancer_laterality"][:, :2].sum(-1) == 1
                ).float()  # only one side is positive
                cancer_side_gold = batch["cancer_laterality"][
                    :, 1
                ]  # left side (seen as lung on right) is positive class
                num_annotated_samples = max(cancer_side_mask.sum(), 1)

                pred_attn = torch.exp(
                    model_output["volume_attention_{}".format(attn_num)]
                )
                # Weight per-slice side attention by volume attention, then
                # collapse slices: (B, N, 2) -> (B, 2).
                side_attn = (side_attn * pred_attn.unsqueeze(-1)).sum(1)
                side_attn_log = F.log_softmax(side_attn, dim=-1)

                # NOTE(review): same double log-softmax concern as the image
                # branch — F.cross_entropy re-normalizes side_attn_log.
                loss = (
                    F.cross_entropy(side_attn_log, cancer_side_gold, reduction="none")
                    * cancer_side_mask
                ).sum() / num_annotated_samples
                logging_dict[
                    "volume_side_attention_loss_{}".format(attn_num)
                ] = loss.detach()
                total_loss += args.volume_attention_loss_lambda * loss

    return total_loss * args.annotation_loss_lambda, logging_dict, predictions
|
|
161 |
|
|
|
162 |
|
|
|
163 |
def get_risk_factor_loss(model_output, batch, model, args):
    """Cross-entropy loss over each risk-factor head, skipping unknown factors.

    For every key in ``args.risk_factor_keys`` the gold risk factor is a
    one-hot row in ``batch['risk_factors'][idx]``; an all-zero row means the
    factor is unknown for that sample and must not contribute to the loss.

    Returns ``(total_loss * args.primary_loss_lambda, logging_dict,
    predictions)`` where predictions hold per-key probabilities, gold class
    indices, and the raw risk-factor rows.
    """
    total_loss, logging_dict, predictions = 0, OrderedDict(), OrderedDict()

    for idx, key in enumerate(args.risk_factor_keys):
        logit = model_output["{}_logit".format(key)]
        gold_rf = batch["risk_factors"][idx]
        # (B,) mask: 1.0 where the factor is known (non-zero one-hot row).
        # FIX: kept 1-D on purpose. The previous `.unsqueeze(-1)` made a
        # (B, 1) mask that broadcast against the (B,) per-sample CE vector
        # into a (B, B) matrix, so the sum became sum(ce) * sum(mask) and
        # unknown samples still leaked into the loss.
        is_rf_known = (torch.sum(gold_rf, dim=-1) > 0).float()

        # Class index per sample; argmax of an all-zero row is 0 but is
        # zeroed out by the mask above.
        gold = torch.argmax(gold_rf, dim=-1).contiguous().view(-1)

        # Average CE over known samples only; max(1, ...) guards an
        # all-unknown batch against division by zero.
        loss = (
            F.cross_entropy(logit, gold, reduction="none") * is_rf_known
        ).sum() / max(1, is_rf_known.sum())
        total_loss += loss
        logging_dict["{}_loss".format(key)] = loss.detach()

        probs = F.softmax(logit, dim=-1).detach()
        predictions["{}_probs".format(key)] = probs.detach()
        predictions["{}_golds".format(key)] = gold.detach()
        predictions["{}_risk_factor".format(key)] = batch["risk_factors"][idx]

    return total_loss * args.primary_loss_lambda, logging_dict, predictions
|
|
186 |
|
|
|
187 |
def discriminator_loss(model_output, batch, model, args):
    """Adversarial dataset-origin loss from the model's discriminator head.

    Runs ``model.discriminator`` on the current output/batch, scores its
    logits against ``batch['origin_dataset']`` with cross-entropy scaled by
    ``args.adv_loss_lambda``, and negates the returned loss when
    ``model.reverse_discrim_loss`` is set. The logged value stays positive.
    """
    logging_dict = OrderedDict()
    predictions = OrderedDict()

    discrim_out = model.discriminator(model_output, batch)
    origin_gold = batch['origin_dataset']

    loss = args.adv_loss_lambda * F.cross_entropy(
        discrim_out['logit'], origin_gold.long()
    )

    logging_dict['discrim_loss'] = loss.detach()
    # Raw logits are stored under 'discrim_probs' (key kept for downstream).
    predictions['discrim_probs'] = discrim_out['logit'].detach()
    predictions['discrim_golds'] = origin_gold

    # Flip the sign to train the upstream encoder adversarially.
    final_loss = -loss if model.reverse_discrim_loss else loss

    return final_loss, logging_dict, predictions