[7e6383]: / old / Pytorch-Covid.py

Download this file

57 lines (49 with data), 2.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
## The code below provides Flatten and the concatenated adaptive average/max pooling layer (from fastai),
## plus a usable head. You must set the sizes of the FC layers manually through the arguments of the myhead function
from torch import Tensor
from torch import nn
import torch
import torchvision
import logging as log
from typing import Optional # required for "Optional[type]"
class Flatten(nn.Module):
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"

    def __init__(self, full: bool = False):
        """Store the flattening mode.

        full: when True the whole tensor is collapsed to rank 1; when False
              (default) the first dimension is kept and the rest is collapsed.
        """
        super().__init__()
        self.full = full

    def forward(self, x):
        """Return `x` flattened to shape `(-1,)` if `self.full`, else `(x.size(0), -1)`."""
        # `reshape` rather than `view`: `view` raises a RuntimeError on
        # non-contiguous input (e.g. after `transpose`/`permute`), while
        # `reshape` still returns a view when the memory layout allows it and
        # only copies when it has to.
        if self.full:
            return x.reshape(-1)
        return x.reshape(x.size(0), -1)
class AdaptiveConcatPool2d(nn.Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`."  # from pytorch

    def __init__(self, sz: Optional[int] = None):
        "Pool to an output size of `sz`; a falsy `sz` (None or 0) means 1."
        super().__init__()
        self.output_size = sz if sz else 1
        # Both pools share the same target size so their results can be
        # concatenated along dimension 1 in `forward`.
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x):
        "Return the max-pooled and avg-pooled `x` concatenated on dim 1 (max first)."
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)
def myhead(nf, nc):
    """Build the head that is appended after the convolutional body.

    nf: number of features fed to the first BatchNorm1d / Linear layer.
    nc: number of outputs of the final Linear layer.

    Returns an `nn.Sequential` of pooling, flattening and two
    BatchNorm -> Dropout -> Linear stages (hidden width 512).
    """
    pooling = [
        AdaptiveConcatPool2d(),
        Flatten(),
    ]
    # The Dropout layers must stay in place: per the original note, the saved
    # weights cannot be loaded without them (presumably because checkpoint
    # keys are indexed by position inside the Sequential).
    hidden_stage = [
        nn.BatchNorm1d(nf, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Dropout(p=0.25, inplace=False),
        nn.Linear(nf, 512, bias=True),
        nn.ReLU(True),
    ]
    output_stage = [
        nn.BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(512, nc, bias=True),
    ]
    return nn.Sequential(*pooling, *hidden_stage, *output_stage)
# Build a ResNet-50 with torchvision's default constructor arguments; the
# trained weights come from the checkpoint loaded below.
my_model = torchvision.models.resnet50()

# Keep every child except the last two (presumably torchvision's own pooling
# and classifier layers), which are replaced by the custom head.
modules = list(my_model.children())[:-2]
body = nn.Sequential(*modules)

# Append the special fastai head.
# Configured according to the model architecture: 4096 input features
# (presumably ResNet-50's 2048 channels doubled by AdaptiveConcatPool2d)
# and 2 output classes.
head = myhead(4096, 2)

# Two top-level children (body, head): same layout as the original
# construction, so the state-dict keys are unchanged.
model_r50 = nn.Sequential(body, head)

# `path` is not assigned anywhere in this file; it is expected to be supplied
# by the surrounding session as a string prefix ending with a path separator.
# Fall back to the current directory instead of raising NameError.
path = globals().get('path', '')

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# map_location='cpu' lets a checkpoint saved from a GPU be loaded on a
# CPU-only machine; load_state_dict then copies into the model's parameters.
state = torch.load(path + 'Corona_model_stage5.pth', map_location='cpu')
model_r50.load_state_dict(state['model'])