Diff of /model.py [000000] .. [fa8046]

Switch to unified view

a b/model.py
1
import torch
2
import torch.nn as nn
3
from torchvision import models
4
5
6
class MRNet(nn.Module):
7
    def __init__(self):
8
        super().__init__()
9
        self.pretrained_model = models.alexnet(pretrained=True)
10
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
11
        self.classifer = nn.Linear(256, 2)
12
13
    def forward(self, x):
14
        x = torch.squeeze(x, dim=0) 
15
        features = self.pretrained_model.features(x)
16
        pooled_features = self.pooling_layer(features)
17
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
18
        flattened_features = torch.max(pooled_features, 0, keepdim=True)[0]
19
        output = self.classifer(flattened_features)
20
        return output