# utils/my_model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, lr_scheduler

from src.loss_functions.losses import AsymmetricLoss, ASLSingleLabel

torch.backends.cudnn.benchmark = False  # You can set it to True if you experience performance gains
torch.backends.cudnn.deterministic = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyCNN(nn.Module):
    def __init__(self, num_classes=12, dropout_prob=0.2, in_channels=3):
        super(MyCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)  # Defined but not used in forward()
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # 64 * 3 * 3 assumes a 3x3 feature map after the three max-pools below
        self.fc1 = nn.Linear(64 * 3 * 3, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, num_classes)

        # Dropout layers
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(p=dropout_prob)

    def forward(self, x_input):
        # Apply convolutional and pooling layers
        x = F.leaky_relu(self.conv1(x_input))
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        # Flatten the output for the fully connected layers
        x = x.view(x.size(0), -1)

        # Apply fully connected layers with dropout
        x = self.dropout1(x)
        x = F.leaky_relu(self.fc1(x))
        x = self.dropout2(x)
        x = F.leaky_relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x
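

# A quick shape sanity check (a sketch: the 64 * 3 * 3 input of fc1 implies
# 24x24 spatial inputs, halved by each of the three 2x2 max-pools):
#
#   model = MyCNN(num_classes=12, dropout_prob=0.2, in_channels=3)
#   print(model(torch.randn(2, 3, 24, 24)).shape)  # torch.Size([2, 12])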

# Rest of the code remains unchanged

# Initialize the model
cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=256).to(device)
cell_attribute_model.train()  # Set the model in training mode

# Initialize optimizer, criterion, and scheduler
optimizer_cell_model = SGD(cell_attribute_model.parameters(), lr=0.01, weight_decay=0.01)
step_size = 5
gamma = 0.1
scheduler_cell_model = lr_scheduler.StepLR(optimizer_cell_model, step_size=step_size, gamma=gamma)

# criterion = nn.CrossEntropyLoss()
criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.08, disable_torch_grad_focal_loss=True)
# criterion = ASLSingleLabel()
# criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss (for num_classes = 2)
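
# Note: AsymmetricLoss is assumed here to follow the ASL repository's
# interface, taking raw logits and a target tensor of the same shape, which
# is why the labels are one-hot encoded inside cell_training below.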


def cell_training(cell_attribute_model_main, cell_datas, labels):
    obj_batch_size = len(cell_datas)  # Note: currently unused
    # optimizer_cell_model.zero_grad()

    # Filter out instances with label == 2: drop any row where an element
    # after the first column equals 2, along with its cell_datas.
    # labels[:, 1:] already excludes the first column, so the whole sliced
    # row is tested.
    valid_indices = [i for i, row in enumerate(labels[:, 1:]) if not torch.any(row == 2).item()]
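    # Example (hypothetical values): with
    #   labels = tensor([[9, 0, 1, 0, 1, 0, 1],
    #                    [9, 0, 2, 0, 1, 0, 1]])
    # row 0 is kept and row 1 is dropped, so valid_indices == [0]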

    if not valid_indices:
        # print("No valid instances, skipping training.")
        # Return a zero loss that still participates in autograd
        object_batch_loss = torch.tensor(0.0, requires_grad=True, device=device)
        return object_batch_loss

    filtered_cell_datas = [cell_datas[i] for i in valid_indices]
    filtered_labels = labels[:, 1:][valid_indices]

    # Each element in filtered_cell_datas is assumed to have shape
    # (1, in_channels, height, width); stacking gives
    # (N, 1, in_channels, height, width) and squeeze(1) removes the
    # singleton dimension
    cell_images = torch.stack(filtered_cell_datas).to(device)
    cell_datas_batch = cell_images.squeeze(1)
    filtered_labels = filtered_labels.to(device)

    # Initialize the model with a dynamically determined in_channels
    # in_channels = filtered_cell_datas[0].size(1)
    # cell_attribute_model_main.conv1.in_channels = in_channels

    # Forward pass
    outputs_my = cell_attribute_model_main(cell_datas_batch.float())
    outputs_my = outputs_my.view(len(valid_indices), -1)

    # Process labels to create target_tensor
    # label_att = filtered_labels[:, 5].float()  # Assuming label[5] contains 0 or 1
    # target_tensor = label_att.view(-1, 1)

    # Compute the loss: each remaining label column is a binary attribute,
    # one-hot encoded over two classes
    num_classes = 2
    one_hot_encoded_tensors = []

    # Perform one-hot encoding for each column separately
    for i in range(filtered_labels.size(1)):
        # Extract the current column
        column_values = filtered_labels[:, i].long()

        # Generate the one-hot encoded tensor for the current column
        one_hot_encoded_col = torch.eye(num_classes, device=filtered_labels.device)[column_values]

        # Add a column axis so the per-column tensors can be concatenated
        one_hot_encoded_col = one_hot_encoded_col.unsqueeze(1)

        one_hot_encoded_tensors.append(one_hot_encoded_col)

    # Concatenate the one-hot encoded tensors along the second dimension (dim=1)
    one_hot_encoded_result = torch.cat(one_hot_encoded_tensors, dim=1)
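
    # An equivalent vectorized alternative (a sketch; F.one_hot returns int64,
    # so it is cast to float to match torch.eye above):
    # one_hot_encoded_result = F.one_hot(filtered_labels.long(), num_classes).float()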
    # The 12 logits per instance are interpreted as 6 binary attributes with
    # 2 classes each, matching the one-hot targets of shape (N, 6, 2)
    outputs_my = outputs_my.view(outputs_my.size(0), 6, 2)

    object_batch_loss = criterion(outputs_my, one_hot_encoded_result)

    # If the loss contains NaN, trigger a breakpoint to inspect variables
    if torch.isnan(object_batch_loss):
        breakpoint()

    torch.use_deterministic_algorithms(False, warn_only=True)

    # Average the loss over the valid instances; the backward pass and
    # optimizer step are currently left to the caller
    object_batch_loss = object_batch_loss / len(filtered_labels)
    # object_batch_loss.backward(retain_graph=True)
    # optimizer_cell_model.step()
    # scheduler_cell_model.step()

    # Explicitly release tensors
    # del cell_images, target_tensor
    # torch.cuda.empty_cache()

    return object_batch_loss
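

# A minimal usage sketch with dummy data (hypothetical shapes: 256-channel
# cell crops at 24x24, and labels with an id column followed by 6 binary
# attributes). Since the backward pass inside cell_training is commented out,
# the caller is assumed to drive the optimizer:
if __name__ == "__main__":
    dummy_cell_datas = [torch.randn(1, 256, 24, 24) for _ in range(4)]
    dummy_labels = torch.randint(0, 2, (4, 7))
    loss = cell_training(cell_attribute_model, dummy_cell_datas, dummy_labels)
    if loss.requires_grad:
        optimizer_cell_model.zero_grad()
        loss.backward()
        optimizer_cell_model.step()
    print(f"batch loss: {loss.item():.4f}")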