--- a
+++ b/model.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+
+
+class MRNet(nn.Module):
+    """Slice-wise AlexNet classifier for a stack of 2-D image slices.
+
+    Each slice goes through AlexNet's convolutional ``features`` stack and
+    is average-pooled to a 256-d vector; the vectors are max-pooled across
+    slices and fed to a single linear layer. One stack is processed per
+    forward call (batch size 1).
+    """
+
+    def __init__(self, num_classes=2, pretrained=True):
+        """Build the AlexNet backbone, pooling layer and linear head.
+
+        Args:
+            num_classes: Number of output logits (default 2, the value
+                previously hard-coded).
+            pretrained: Forwarded to ``torchvision.models.alexnet``; when
+                true, torchvision loads its pretrained weights, which may
+                trigger a download on first use.
+        """
+        super().__init__()
+        # Only ``.features`` is used in forward(); AlexNet's own avgpool and
+        # classifier remain registered, so state_dict keys are unchanged.
+        self.pretrained_model = models.alexnet(pretrained=pretrained)
+        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
+        # 256 is the channel count of AlexNet's last conv layer. The
+        # attribute name is misspelled ("classifer"), but renaming it would
+        # change state_dict keys and break existing checkpoints.
+        self.classifer = nn.Linear(256, num_classes)
+
+    def forward(self, x):
+        """Return logits of shape ``(1, num_classes)`` for one slice stack.
+
+        Args:
+            x: A stack of slices, either ``(1, S, C, H, W)`` or
+                ``(S, C, H, W)``; C must equal AlexNet's input channels
+                (3 for torchvision's AlexNet).
+
+        Raises:
+            ValueError: If ``x`` is 5-D with a batch size other than 1, or
+                is neither 4-D nor 5-D.
+        """
+        # Strip only the batch axis: ``torch.squeeze(x, dim=0)`` would also
+        # collapse the slice axis of a 4-D single-slice stack, and would
+        # silently keep batches > 1, which Conv2d rejects as 5-D input.
+        if x.dim() == 5:
+            if x.size(0) != 1:
+                raise ValueError(
+                    "MRNet processes one slice stack per call; "
+                    f"got batch size {x.size(0)}"
+                )
+            x = x[0]
+        elif x.dim() != 4:
+            raise ValueError(
+                f"expected a 4-D or 5-D input, got shape {tuple(x.shape)}"
+            )
+
+        features = self.pretrained_model.features(x)  # (S, 256, h, w)
+        pooled_features = self.pooling_layer(features)  # (S, 256, 1, 1)
+        pooled_features = pooled_features.view(pooled_features.size(0), -1)
+        # Max over the slice axis -> one (1, 256) descriptor for the stack.
+        stack_features = torch.max(pooled_features, 0, keepdim=True)[0]
+        output = self.classifer(stack_features)
+        return output