Diff of /hubconf.py [000000] .. [9cc651]

Switch to unified view

a b/hubconf.py
1
dependencies = ["torch"]
2
3
import torch
4
5
from unet import UNet
6
7
8
def unet(pretrained=False, **kwargs):
9
    """
10
    U-Net segmentation model with batch normalization for biomedical image segmentation
11
    pretrained (bool): load pretrained weights into the model
12
    in_channels (int): number of input channels
13
    out_channels (int): number of output channels
14
    init_features (int): number of feature-maps in the first encoder layer
15
    """
16
    model = UNet(**kwargs)
17
18
    if pretrained:
19
        checkpoint = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt"
20
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=False, map_location='cpu')
21
        model.load_state_dict(state_dict)
22
23
    return model