Diff of /helper.py [000000] .. [d366d1]

Switch to unified view

a b/helper.py
1
import torch
2
import torch.nn as nn
3
4
# Example model definition (use your actual architecture)
5
class UNet(nn.Module):
6
    def __init__(self):
7
        super(UNet, self).__init__()
8
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
9
        self.relu = nn.ReLU()
10
        self.conv2 = nn.Conv2d(64, 1, kernel_size=3, padding=1)
11
12
    def forward(self, x):
13
        x = self.relu(self.conv1(x))
14
        x = self.conv2(x)
15
        return x
16
17
# Step 1: Create the model instance
18
model = UNet()
19
20
# Step 2: Load the state dictionary
21
state_dict = torch.load('leukemia_cells_unet.pt', map_location=torch.device('cpu'), weights_only=True)
22
23
# Step 3: Load the weights into the model instance
24
model.load_state_dict(state_dict)
25
26
# Step 4: Set the model to evaluation mode
27
model.eval()