[068794]: / hubconf.py

Download this file

24 lines (17 with data), 773 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Packages torch.hub checks are importable before running any entrypoint
# defined in this hubconf.
dependencies = [
    "torch",
]
import torch
from unet import UNet
def unet(pretrained=False, *, progress=False, **kwargs) -> "UNet":
    """Build a U-Net with batch normalization for biomedical image segmentation.

    This is the torch.hub entrypoint of this repository. Every keyword argument
    other than ``pretrained`` and ``progress`` is forwarded unchanged to
    ``UNet``; the ones documented for this model are:

        in_channels (int): number of input channels
        out_channels (int): number of output channels
        init_features (int): number of feature-maps in the first encoder layer

    Args:
        pretrained (bool): if True, fetch the v1.0 release checkpoint with
            ``torch.hub.load_state_dict_from_url`` (which caches the download)
            and load its weights into the model.
        progress (bool): whether to display a download progress bar on stderr.
            Defaults to False, the value previously hard-coded here.
        **kwargs: forwarded to ``UNet``.

    Returns:
        The constructed ``UNet`` instance, with the pretrained weights loaded
        when ``pretrained`` is True.

    Raises:
        RuntimeError: if ``pretrained`` is True and the checkpoint's tensors do
            not fit the model built from ``kwargs`` (e.g. a different
            ``in_channels``); the underlying error is chained as the cause.
    """
    model = UNet(**kwargs)
    if pretrained:
        checkpoint = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt"
        # Load onto CPU so the weights deserialize regardless of the device
        # they were saved from; the caller can move the model afterwards.
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, progress=progress, map_location="cpu"
        )
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as err:
            # A strict load fails on missing/unexpected keys or shape
            # mismatches, which here means kwargs changed the architecture.
            raise RuntimeError(
                f"pretrained weights from {checkpoint} do not fit "
                f"UNet(**{kwargs!r}); pretrained=True requires the "
                "architecture the checkpoint was trained with"
            ) from err
    return model