--- a +++ b/model.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn +from torchvision import models + + +class MRNet(nn.Module): + def __init__(self): + super().__init__() + self.pretrained_model = models.alexnet(pretrained=True) + self.pooling_layer = nn.AdaptiveAvgPool2d(1) + self.classifer = nn.Linear(256, 2) + + def forward(self, x): + x = torch.squeeze(x, dim=0) + features = self.pretrained_model.features(x) + pooled_features = self.pooling_layer(features) + pooled_features = pooled_features.view(pooled_features.size(0), -1) + flattened_features = torch.max(pooled_features, 0, keepdim=True)[0] + output = self.classifer(flattened_features) + return output