|
a |
|
b/helper.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
|
|
|
4 |
# Example model definition (use your actual architecture) |
|
|
5 |
class UNet(nn.Module): |
|
|
6 |
def __init__(self): |
|
|
7 |
super(UNet, self).__init__() |
|
|
8 |
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) |
|
|
9 |
self.relu = nn.ReLU() |
|
|
10 |
self.conv2 = nn.Conv2d(64, 1, kernel_size=3, padding=1) |
|
|
11 |
|
|
|
12 |
def forward(self, x): |
|
|
13 |
x = self.relu(self.conv1(x)) |
|
|
14 |
x = self.conv2(x) |
|
|
15 |
return x |
|
|
16 |
|
|
|
17 |
# Step 1: Create the model instance |
|
|
18 |
model = UNet() |
|
|
19 |
|
|
|
20 |
# Step 2: Load the state dictionary |
|
|
21 |
state_dict = torch.load('leukemia_cells_unet.pt', map_location=torch.device('cpu'), weights_only=True) |
|
|
22 |
|
|
|
23 |
# Step 3: Load the weights into the model instance |
|
|
24 |
model.load_state_dict(state_dict) |
|
|
25 |
|
|
|
26 |
# Step 4: Set the model to evaluation mode |
|
|
27 |
model.eval() |