a b/Generation/loss.py
1
"""
2
 Copyright (c) 2022, salesforce.com, inc.
3
 All rights reserved.
4
 SPDX-License-Identifier: BSD-3-Clause
5
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
"""
7
8
import logging
9
import torch
10
import torch.distributed.nn
11
from torch import distributed as dist, nn as nn
12
from torch.nn import functional as F
13
14
try:
15
    import horovod.torch as hvd
16
except ImportError:
17
    hvd = None
18
19
20
def gather_features(
    image_features,
    text_features,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
):
    """Gather image and text features from every rank of the process group.

    Args:
        image_features: This rank's image feature tensor. In the
            ``torch.distributed`` no-grad path, receive buffers are allocated
            with ``torch.zeros_like(image_features)``. The Horovod path splits
            the result with ``chunk(world_size)``. Both therefore assume every
            rank contributes a tensor of the same shape (TODO confirm callers
            keep per-rank batch sizes equal).
        text_features: This rank's text feature tensor; same convention as
            ``image_features``.
        local_loss: Only matters when ``gather_with_grad`` is False. If True,
            the gathered tensors are returned as-is and carry no gradient,
            which suits callers that take gradients only through their own
            local features. If False, this rank's slot in the gathered result
            is replaced by the local tensors so gradients reach them.
        gather_with_grad: If True, gather outside ``torch.no_grad`` with the
            autograd-aware collectives (``hvd.allgather`` /
            ``torch.distributed.nn.all_gather``).
        rank: Index of this process; selects which slot to replace with the
            local tensors.
        world_size: Number of participating processes.
        use_horovod: Use Horovod collectives instead of ``torch.distributed``.

    Returns:
        Tuple ``(all_image_features, all_text_features)``. Each holds the
        per-rank tensors concatenated along dim 0.

    Raises:
        ImportError: If ``use_horovod`` is True but Horovod is not installed.
    """
    if use_horovod:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an opaque AttributeError on
        # `None.allgather` further down.
        if hvd is None:
            raise ImportError("Please install horovod")
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(
                    all_image_features.chunk(world_size, dim=0)
                )
                gathered_text_features = list(
                    all_text_features.chunk(world_size, dim=0)
                )
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            # torch.distributed.nn.all_gather participates in autograd, so
            # gradients flow back to every rank's features.
            all_image_features = torch.cat(
                torch.distributed.nn.all_gather(image_features), dim=0
            )
            all_text_features = torch.cat(
                torch.distributed.nn.all_gather(text_features), dim=0
            )
        else:
            # dist.all_gather fills the preallocated buffers in place; the
            # buffers (zeros_like) do not require grad.
            gathered_image_features = [
                torch.zeros_like(image_features) for _ in range(world_size)
            ]
            gathered_text_features = [
                torch.zeros_like(text_features) for _ in range(world_size)
            ]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features
76
77
78
class ClipLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss between image and text features.

    Similarity logits are built in both directions (image->text and
    text->image). Each direction is scored with cross-entropy against integer
    targets ``0..N-1``, optionally offset by this rank's slice when
    ``local_loss`` is used. The two losses are then averaged. Row ``i`` of the
    image features is therefore treated as the match for row ``i`` of the text
    features.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super().__init__()
        # Distributed topology / backend selection.
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # Behaviour switches forwarded to gather_features / label building.
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels

        # Label cache: last cached row count plus a device -> labels mapping.
        # Only written when cache_labels is True.
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features, logit_scale):
        """Compute the averaged image->text and text->image cross-entropy.

        Args:
            image_features: Local image features, presumably ``(batch, dim)``
                (TODO confirm); multiplied against the transposed text
                features.
            text_features: Local text features with the same layout.
            logit_scale: Multiplier applied to the similarity logits.

        Returns:
            Scalar tensor: ``(CE(logits_per_image) + CE(logits_per_text)) / 2``.
        """
        device = image_features.device

        if self.world_size > 1:
            gathered_image, gathered_text = gather_features(
                image_features,
                text_features,
                self.local_loss,
                self.gather_with_grad,
                self.rank,
                self.world_size,
                self.use_horovod,
            )
            if not self.local_loss:
                # Global similarity matrix; the text view is just its transpose.
                logits_per_image = logit_scale * gathered_image @ gathered_text.T
                logits_per_text = logits_per_image.T
            else:
                # Only this rank's rows, scored against every rank's columns.
                logits_per_image = logit_scale * image_features @ gathered_text.T
                logits_per_text = logit_scale * text_features @ gathered_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        # Ground-truth targets, reused from the cache when shape and device match.
        num_logits = logits_per_image.shape[0]
        can_reuse = self.prev_num_logits == num_logits and device in self.labels
        if can_reuse:
            labels = self.labels[device]
        else:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.local_loss and self.world_size > 1:
                # Local rows match the columns of this rank's slice.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.prev_num_logits = num_logits
                self.labels[device] = labels

        image_loss = F.cross_entropy(logits_per_image, labels)
        text_loss = F.cross_entropy(logits_per_text, labels)
        return (image_loss + text_loss) / 2