b/Generation/loss.py

"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import torch
import torch.distributed.nn
from torch import distributed as dist, nn as nn
from torch.nn import functional as F

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
    image_features,
    text_features,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
):
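    """Gather image/text features from every distributed worker.

    Returns (all_image_features, all_text_features) concatenated over the
    global batch. With gather_with_grad=True the gathered copies keep their
    autograd graph for all ranks; otherwise the gather is gradient-free and,
    when local_loss is False, this rank's slice is swapped back in so the
    local features still receive gradients.
    """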
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(
                    all_image_features.chunk(world_size, dim=0)
                )
                gathered_text_features = list(
                    all_text_features.chunk(world_size, dim=0)
                )
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(
                torch.distributed.nn.all_gather(image_features), dim=0
            )
            all_text_features = torch.cat(
                torch.distributed.nn.all_gather(text_features), dim=0
            )
        else:
            gathered_image_features = [
                torch.zeros_like(image_features) for _ in range(world_size)
            ]
            gathered_text_features = [
                torch.zeros_like(text_features) for _ in range(world_size)
            ]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ClipLoss(nn.Module):
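    """Symmetric CLIP-style contrastive (InfoNCE) loss.

    Computes cross-entropy over image-to-text and text-to-image similarity
    logits and averages the two directions. Feature batches are produced
    per GPU (and are typically already L2-normalized upstream).
    """
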
    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features, logit_scale):
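        """Return the scalar CLIP loss for one batch of paired features.

        logit_scale is the learned inverse-temperature multiplier (already
        exponentiated by the caller in typical CLIP training loops) applied
        to the feature similarities.
        """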
        device = image_features.device
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features,
                text_features,
                self.local_loss,
                self.gather_with_grad,
                self.rank,
                self.world_size,
                self.use_horovod,
            )

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = (
                    logit_scale * all_image_features @ all_text_features.T
                )
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        # calculate ground-truth labels and cache them if enabled
        num_logits = logits_per_image.shape[0]
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]

        total_loss = (
            F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_text, labels)
        ) / 2
        return total_loss
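

# Minimal single-process usage sketch (illustrative only: the batch size,
# feature dimension, and logit-scale value below are made-up assumptions,
# not part of this module):
if __name__ == "__main__":
    torch.manual_seed(0)
    image_features = F.normalize(torch.randn(8, 512), dim=-1)  # fake image embeddings
    text_features = F.normalize(torch.randn(8, 512), dim=-1)   # fake text embeddings
    logit_scale = torch.tensor(100.0)  # scalar inverse temperature (hypothetical value)
    loss = ClipLoss()(image_features, text_features, logit_scale)
    print(f"loss: {loss.item():.4f}")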