Diff of /hubconf.py [000000] .. [9cc651]

Switch to side-by-side view

--- a
+++ b/hubconf.py
@@ -0,0 +1,23 @@
+dependencies = ["torch"]
+
+import torch
+
+from unet import UNet
+
+
+def unet(pretrained=False, **kwargs):
+    """
+    U-Net segmentation model with batch normalization for biomedical image segmentation
+    pretrained (bool): load pretrained weights into the model
+    in_channels (int): number of input channels
+    out_channels (int): number of output channels
+    init_features (int): number of feature-maps in the first encoder layer
+    """
+    model = UNet(**kwargs)
+
+    if pretrained:
+        checkpoint = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt"
+        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=False, map_location='cpu')
+        model.load_state_dict(state_dict)
+
+    return model