b/biovil_t/modules.py
|
|
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from typing import Callable, Optional

import torch
import torch.nn as nn


class MLP(nn.Module):
    """
    Fully connected layers that map image embeddings into the projection space in which pairs of images are compared.

    :param input_dim: Input embedding feature size.
    :param output_dim: Output projection size.
    :param hidden_dim: Hidden layer size in the MLP. If None, the model reduces to a single linear projection.
    :param use_1x1_convs: Use 1x1 convolution kernels instead of linear layers, for speed and memory efficiency.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: Optional[int] = None,
                 use_1x1_convs: bool = False) -> None:
        super().__init__()

        if use_1x1_convs:
            # 1x1 convolutions act as position-wise linear layers on [N, C, H, W] feature maps,
            # so spatial inputs need not be flattened before projection.
            linear_proj_1_args = {'in_channels': input_dim, 'out_channels': hidden_dim, 'kernel_size': 1, 'bias': False}
            linear_proj_2_args = {'in_channels': hidden_dim, 'out_channels': output_dim, 'kernel_size': 1, 'bias': True}
            normalisation_layer: Callable = nn.BatchNorm2d
            projection_layer: Callable = nn.Conv2d
        else:
            linear_proj_1_args = {'in_features': input_dim, 'out_features': hidden_dim, 'bias': False}
            linear_proj_2_args = {'in_features': hidden_dim, 'out_features': output_dim, 'bias': True}
            normalisation_layer = nn.BatchNorm1d
            projection_layer = nn.Linear

        self.output_dim = output_dim
        self.input_dim = input_dim
        if hidden_dim is not None:
            self.model = nn.Sequential(
                projection_layer(**linear_proj_1_args),
                normalisation_layer(hidden_dim),
                nn.ReLU(inplace=True),
                projection_layer(**linear_proj_2_args))
        else:
            # Without a hidden layer, fall back to a single linear projection.
            self.model = nn.Linear(input_dim, output_dim)  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the multi-layer perceptron."""
        x = self.model(x)
        return x
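

# A minimal usage sketch for MLP (illustrative only; the dimensions below are hypothetical).
# With the default constructor the module maps [batch, features] tensors; with
# use_1x1_convs=True it maps [batch, channels, H, W] feature maps position-wise.
#
#   mlp = MLP(input_dim=128, output_dim=32, hidden_dim=64)
#   y = mlp(torch.randn(4, 128))              # -> shape [4, 32]
#
#   conv_mlp = MLP(input_dim=128, output_dim=32, hidden_dim=64, use_1x1_convs=True)
#   y = conv_mlp(torch.randn(4, 128, 7, 7))   # -> shape [4, 32, 7, 7]
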
class MultiTaskModel(nn.Module):
    """Torch module for multi-task classification heads. We create a separate classification head
    for each task and perform a forward pass on each head independently in forward(). Classification
    heads are instances of `MLP`.

    :param input_dim: Number of dimensions of the input feature map.
    :param classifier_hidden_dim: Number of dimensions of hidden features in the MLP.
    :param num_classes: Number of output classes per task.
    :param num_tasks: Number of classification tasks or heads required.
    """

    def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):
        super().__init__()

        self.num_classes = num_classes
        self.num_tasks = num_tasks

        # Register one head per task as a named attribute so that nn.Module picks up
        # its parameters (for state_dict, .to(device), optimisers, etc.).
        for task in range(num_tasks):
            setattr(self, "fc_" + str(task), MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns a [batch_size, num_classes, num_tasks] tensor of logits."""
        batch_size = x.shape[0]
        out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
        for task in range(self.num_tasks):
            classifier = getattr(self, "fc_" + str(task))
            out[:, :, task] = classifier(x)
        return out
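

# Minimal smoke test (illustrative sketch; all dimensions are hypothetical and chosen
# only to demonstrate the expected tensor shapes of the modules above).
if __name__ == "__main__":
    torch.manual_seed(0)

    # Three classification heads, each mapping a 128-d feature vector to 4 class logits.
    model = MultiTaskModel(input_dim=128, classifier_hidden_dim=64, num_classes=4, num_tasks=3)
    features = torch.randn(8, 128)    # [batch_size, input_dim]
    logits = model(features)
    print(logits.shape)               # torch.Size([8, 4, 3]) == [batch_size, num_classes, num_tasks]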