|
a |
|
b/model.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from torchvision import models |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class MRNet(nn.Module): |
|
|
7 |
def __init__(self): |
|
|
8 |
super().__init__() |
|
|
9 |
self.pretrained_model = models.alexnet(pretrained=True) |
|
|
10 |
self.pooling_layer = nn.AdaptiveAvgPool2d(1) |
|
|
11 |
self.classifer = nn.Linear(256, 2) |
|
|
12 |
|
|
|
13 |
def forward(self, x): |
|
|
14 |
x = torch.squeeze(x, dim=0) |
|
|
15 |
features = self.pretrained_model.features(x) |
|
|
16 |
pooled_features = self.pooling_layer(features) |
|
|
17 |
pooled_features = pooled_features.view(pooled_features.size(0), -1) |
|
|
18 |
flattened_features = torch.max(pooled_features, 0, keepdim=True)[0] |
|
|
19 |
output = self.classifer(flattened_features) |
|
|
20 |
return output |