Diff of /model/resnet.py [000000] .. [ccc736]

Switch to unified view

a b/model/resnet.py
1
import torch.nn as nn
2
3
from torchvision import models
4
5
6
class ResnetModel(nn.Module):
7
8
    def __init__(self, classes):
9
        """
10
        Arguments
11
        ---------
12
        classes: número de clases (tipos de hemorragias)
13
        """
14
        super(ResnetModel, self).__init__()
15
        self.backbone = models.resnet50(pretrained=False)
16
        n_filters = self.backbone.fc.in_features
17
        self.backbone.fc = nn.Linear(n_filters, classes)
18
19
    def forward(self, x):
20
        x = self.backbone(x)
21
22
        return x