Diff of /utils/my_model.py [000000] .. [190ca4]


--- a
+++ b/utils/my_model.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.optim import SGD, lr_scheduler
+
+torch.backends.cudnn.benchmark = False  # Set to True for potential speedups when input sizes are fixed
+torch.backends.cudnn.deterministic = False
+from src.loss_functions.losses import AsymmetricLoss, ASLSingleLabel
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+class MyCNN(nn.Module):
+    def __init__(self, num_classes=12, dropout_prob=0.2, in_channels=3):
+        super(MyCNN, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
+        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)  # Note: defined but never used in forward
+        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.fc1 = nn.Linear(64 * 3 * 3, 1024)  # 64 channels x 3x3 spatial; assumes 24x24 inputs
+        self.fc2 = nn.Linear(1024, 256)
+        self.fc3 = nn.Linear(256, num_classes)
+
+        # Dropout layers
+        self.dropout1 = nn.Dropout(p=dropout_prob)
+        self.dropout2 = nn.Dropout(p=dropout_prob)
+
+    def forward(self, x_input):
+        # Apply convolutional and pooling layers
+        x = F.leaky_relu(self.conv1(x_input))
+        x = F.max_pool2d(x, 2)
+        x = F.leaky_relu(self.conv2(x))
+        x = F.max_pool2d(x, 2)
+        x = F.leaky_relu(self.conv3(x))
+        x = F.max_pool2d(x, 2)
+
+        # Flatten the output for the fully connected layers
+        x = x.view(x.size(0), -1)
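+        # x is now (N, 64*3*3) = (N, 576): assuming 24x24 inputs, the spatial
+        # size shrinks 24 -> 12 -> 6 -> 3 across the three max-pool stages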
+        x = self.dropout1(x)
+        x = F.leaky_relu(self.fc1(x))
+        x = self.dropout2(x)
+        x = F.leaky_relu(self.fc2(x))
+
+        # Final dropout and classification layer (dropout2 is reused here;
+        # nn.Dropout is stateless, so sharing the module is harmless)
+        x = self.dropout2(x)
+        x = self.fc3(x)
+        return x
+
+
+# Initialize the model
+cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=256).to(device)
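+# (Hedged note: in_channels=256 suggests the cell inputs are feature-map crops
+# rather than 3-channel RGB images; the default in_channels=3 is overridden here.)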
+cell_attribute_model.train()  # Set the model in training mode
+
+# Initialize optimizer, criterion, and scheduler
+optimizer_cell_model = SGD(cell_attribute_model.parameters(), lr=0.01, weight_decay=0.01)
+step_size = 5
+gamma = 0.1
+scheduler_cell_model = lr_scheduler.StepLR(optimizer_cell_model, step_size=step_size, gamma=gamma)
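+# StepLR multiplies the learning rate by gamma every step_size epochs:
+# with these settings, lr goes 0.01 -> 0.001 -> 0.0001 at epochs 5 and 10
+# (assuming scheduler_cell_model.step() is called once per epoch).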
+#criterion = nn.CrossEntropyLoss()
+criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.08, disable_torch_grad_focal_loss=True)
+# criterion = ASLSingleLabel()
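+# Note (based on the ASL reference implementation): AsymmetricLoss applies sigmoid
+# internally, so it expects raw logits and a {0, 1} target tensor of the same shape.
+# gamma_neg=4 > gamma_pos=1 down-weights easy negatives, and clip=0.08 additionally
+# shifts negative probabilities so very easy negatives contribute no gradient.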
+
+
+# Alternative for binary targets:
+# criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy with logits
+
+def cell_training(cell_attribute_model_main, cell_datas, labels):
+    obj_batch_size = len(cell_datas)
+    # The optimizer step is handled by the caller; see the commented-out
+    # backward/step calls at the end of this function.
+    # optimizer_cell_model.zero_grad()
+
+    # Filter out instances whose attribute labels contain the sentinel value 2.
+    # labels[:, 0] is an id column, so only labels[:, 1:] is inspected.
+    valid_indices = [i for i, row in enumerate(labels[:, 1:]) if not torch.any(row == 2).item()]
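+    # Example: a labels row [id, 0, 1, 2, 0, 1, 0] is skipped because one of its
+    # six attribute columns carries the "unknown" sentinel value 2.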
+    
+    if not valid_indices:
+        # No valid instances: return a zero loss that still participates in autograd
+        object_batch_loss = torch.tensor(0.0, requires_grad=True, device=device)
+        return object_batch_loss
+
+    filtered_cell_datas = [cell_datas[i] for i in valid_indices]
+    filtered_labels = labels[:, 1:][valid_indices]
+
+    # Each element of filtered_cell_datas is assumed to be shaped (1, in_channels, H, W);
+    # stacking yields (N, 1, C, H, W) and squeeze(1) drops the singleton dimension.
+    cell_images = torch.stack(filtered_cell_datas).to(device)
+    cell_datas_batch = cell_images.squeeze(1)
+    filtered_labels = filtered_labels.to(device)
+
+    # Note: reassigning conv1.in_channels at runtime (as in the earlier draft below)
+    # would not reallocate the conv weights; in_channels must be fixed at construction.
+    # in_channels = filtered_cell_datas[0].size(1)
+    # cell_attribute_model_main.conv1.in_channels = in_channels
+
+    # Forward pass
+    outputs_my = cell_attribute_model_main(cell_datas_batch.float())
+    outputs_my = outputs_my.view(len(valid_indices), -1)
+
+    # Process labels to create target_tensor
+    # label_att = filtered_labels[:, 5].float()  # Assuming label[5] contains 0 or 1
+    # target_tensor = label_att.view(-1, 1)
+
+    # Build one-hot targets: each attribute column is binary, so num_classes = 2
+    num_classes = 2
+    one_hot_encoded_tensors = []
+
+    # Perform one-hot encoding for each column separately
+    for i in range(filtered_labels.size(1)):
+        # Extract the current column
+        column_values = filtered_labels[:, i].long()
+
+        # Generate one-hot encoded tensor for the current column
+        one_hot_encoded_col = torch.eye(num_classes, device=filtered_labels.device)[column_values]
+
+        # Reshape to match the original shape
+        one_hot_encoded_col = one_hot_encoded_col.unsqueeze(1)
+
+        one_hot_encoded_tensors.append(one_hot_encoded_col)
+
+    # Concatenate the one-hot encoded tensors along the second dimension (axis=1)
+    one_hot_encoded_result = torch.cat(one_hot_encoded_tensors, dim=1)
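+    # Equivalent one-liner: F.one_hot(filtered_labels.long(), num_classes).float()
+    # yields the same (N, num_attributes, 2) tensor in a single call.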
+    # Reshape the 12 logits to (N, 6, 2): six binary attributes, two classes each
+    outputs_my = outputs_my.view(outputs_my.size(0), 6, 2)
+
+    object_batch_loss = criterion(outputs_my, one_hot_encoded_result)
+
+    # Check if the loss contains NaN
+    if torch.isnan(object_batch_loss):
+        # If NaN, trigger a breakpoint to inspect variables
+        breakpoint()
+
+    # Relax determinism (presumably some ops above lack deterministic implementations)
+    torch.use_deterministic_algorithms(False, warn_only=True)
+
+    # Normalize by the number of valid instances; backward/step are left to the caller
+    object_batch_loss = object_batch_loss / len(filtered_labels)
+    # object_batch_loss.backward(retain_graph=True)
+    # optimizer_cell_model.step()
+    # scheduler_cell_model.step()
+
+    # Explicitly release tensors if GPU memory pressure becomes an issue
+    # del cell_images
+    # torch.cuda.empty_cache()
+
+    return object_batch_loss
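+
+# Minimal smoke-test sketch (hypothetical shapes: 256-channel 24x24 crops and a
+# 7-column label tensor whose first column is an id and whose remaining six
+# columns are 0/1 attribute flags). Run the module directly to sanity-check shapes.
+if __name__ == "__main__":
+    dummy_cells = [torch.randn(1, 256, 24, 24) for _ in range(4)]
+    dummy_labels = torch.randint(0, 2, (4, 7))
+    smoke_loss = cell_training(cell_attribute_model, dummy_cells, dummy_labels)
+    print("smoke-test loss:", smoke_loss.item())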
+