[ccc736]: / model / resnet.py

Download this file

23 lines (16 with data), 520 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.nn as nn
from torchvision import models
class ResnetModel(nn.Module):
    """ResNet-50 classifier whose final fully connected layer is resized.

    A torchvision ResNet-50 is built as the backbone (without pretrained
    weights by default). Its stock ``fc`` head is then replaced by a new
    ``nn.Linear`` that emits one score per class.
    """

    def __init__(self, classes: int, *, pretrained: bool = False):
        """
        Arguments
        ---------
        classes: int
            Number of output classes (types of hemorrhages).
        pretrained: bool, optional
            Forwarded unchanged to ``torchvision.models.resnet50``.
            ``False`` (the default, and the value previously hard-coded)
            builds the backbone without pretrained weights. ``True`` asks
            torchvision for its pretrained weights, which may trigger a
            download. Only the backbone is affected: the replacement ``fc``
            layer is always freshly initialised.
        """
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour
        # of ``weights=``; the legacy keyword is kept so older torchvision
        # releases keep working.
        self.backbone = models.resnet50(pretrained=pretrained)
        # Replace the stock classification head with one sized to ``classes``,
        # reusing the backbone's feature width as the input size.
        n_filters = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(n_filters, classes)

    def forward(self, x):
        """Run ``x`` through the backbone and return the raw class scores.

        No softmax or sigmoid is applied here. The scores come from the
        replacement ``nn.Linear`` head, so their last dimension has size
        ``classes``.

        NOTE(review): ``x`` is presumably an image batch shaped
        (N, 3, H, W), as torchvision's ResNet expects. Confirm this against
        the data pipeline.
        """
        return self.backbone(x)