--- a +++ b/model/resnet.py @@ -0,0 +1,22 @@ +import torch.nn as nn + +from torchvision import models + + +class ResnetModel(nn.Module): + + def __init__(self, classes): + """ + Arguments + --------- + classes: número de clases (tipos de hemorragias) + """ + super(ResnetModel, self).__init__() + self.backbone = models.resnet50(pretrained=False) + n_filters = self.backbone.fc.in_features + self.backbone.fc = nn.Linear(n_filters, classes) + + def forward(self, x): + x = self.backbone(x) + + return x