Diff of /biovil_t/resnet.py [000000] .. [27805f]

Switch to side-by-side view

--- a
+++ b/biovil_t/resnet.py
@@ -0,0 +1,64 @@
+#%%
+from typing import Any, List, Tuple, Type, Union
+
+import torch
+from torch.hub import load_state_dict_from_url
+from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
+from torchvision.models import ResNet18_Weights, ResNet50_Weights
+
+TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+
+
+class ResNetHIML(ResNet):
+  """Wrapper class of the original torchvision ResNet model.
+
+  The forward function is updated to return the penultimate layer
+  activations, which are required to obtain image patch embeddings.
+  """
+
+  def __init__(self, **kwargs: Any) -> None:
+      super().__init__(**kwargs)
+
+  def forward(self, x: torch.Tensor,
+              return_intermediate_layers: bool = False) -> Union[torch.Tensor, TypeSkipConnections]:
+      """ResNetHIML forward pass. Optionally returns intermediate layers using the
+      return_intermediate_layers argument.
+
+      :param return_intermediate_layers: If True, return layers x0-x4 as a tuple,
+          otherwise return x4 only.
+      """
+
+      x0 = self.conv1(x)
+      x0 = self.bn1(x0)
+      x0 = self.relu(x0)
+      x0 = self.maxpool(x0)
+
+      x1 = self.layer1(x0)
+      x2 = self.layer2(x1)
+      x3 = self.layer3(x2)
+      x4 = self.layer4(x3)
+
+      if return_intermediate_layers:
+          return x0, x1, x2, x3, x4
+      else:
+          return x4
+
+
+def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int],
+          pretrained: bool, progress: bool, **kwargs: Any) -> ResNetHIML:
+  """Instantiate a custom :class:ResNet model.
+
+  Adapted from :mod:torchvision.models.resnet.
+  """
+  model = ResNetHIML(block=block, layers=layers, **kwargs)
+  if pretrained:
+      if arch == 'resnet18':
+          weights = ResNet18_Weights.IMAGENET1K_V1
+      elif arch == 'resnet50':
+          weights = ResNet50_Weights.IMAGENET1K_V1
+      else:
+          raise ValueError(f"Pretrained model not available for {arch}")
+      
+      state_dict = weights.get_state_dict(progress=progress)
+      model.load_state_dict(state_dict)
+  return model
\ No newline at end of file