a | b/model/resnet.py | ||
---|---|---|---|
1 | import torch.nn as nn |
||
2 | |||
3 | from torchvision import models |
||
4 | |||
5 | |||
6 | class ResnetModel(nn.Module): |
||
7 | |||
8 | def __init__(self, classes): |
||
9 | """ |
||
10 | Arguments |
||
11 | --------- |
||
12 | classes: número de clases (tipos de hemorragias) |
||
13 | """ |
||
14 | super(ResnetModel, self).__init__() |
||
15 | self.backbone = models.resnet50(pretrained=False) |
||
16 | n_filters = self.backbone.fc.in_features |
||
17 | self.backbone.fc = nn.Linear(n_filters, classes) |
||
18 | |||
19 | def forward(self, x): |
||
20 | x = self.backbone(x) |
||
21 | |||
22 | return x |