# Packages that torch.hub must be able to import before it will load any
# entrypoint defined in this file.
dependencies = [
    "torch",
]
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
from unet import UNet |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def unet(pretrained=False, **kwargs):
    """Build a U-Net with batch normalization for biomedical image segmentation.

    This is a ``torch.hub`` entrypoint.

    Args:
        pretrained (bool): if true, download the released checkpoint and load
            its weights into the freshly constructed model.
        **kwargs: forwarded unchanged to ``UNet``. Documented keys:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            init_features (int): number of feature-maps in the first encoder
                layer.

    Returns:
        The constructed ``UNet`` model, with pretrained weights loaded when
        ``pretrained`` is true.

    Raises:
        RuntimeError: raised by ``load_state_dict`` (strict by default) when
            ``pretrained`` is true and ``kwargs`` produce parameter names or
            shapes that differ from those stored in the checkpoint.

    NOTE(review): the ``UNet`` arguments the checkpoint was trained with are
    not visible here -- presumably ``UNet``'s defaults; confirm before
    combining ``pretrained=True`` with non-default kwargs.
    """
    net = UNet(**kwargs)
    if not pretrained:
        return net

    weights_url = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt"
    # map_location="cpu" remaps every stored tensor onto CPU storage, so a
    # checkpoint saved from a GPU still loads on a CPU-only host; the download
    # progress bar is suppressed.
    weights = torch.hub.load_state_dict_from_url(
        weights_url,
        progress=False,
        map_location="cpu",
    )
    net.load_state_dict(weights)
    return net