In [1]:
# Install required libraries
!pip install onnx onnxruntime tf2onnx --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m102.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m104.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m455.8/455.8 kB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m60.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
grpcio-status 1.62.3 requires protobuf>=4.21.6, but you have protobuf 3.20.3 which is incompatible.

In [2]:
import torch
import numpy as np
import onnx
import onnxruntime as ort
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score
from google.colab import files

In [3]:
# Dataset Directories
train_dir = '/content/drive/MyDrive/update dataset/New Chest(5 class)/Train'
val_dir = '/content/drive/MyDrive/update dataset/New Chest(5 class)/Val'
test_dir = '/content/drive/MyDrive/update dataset/New Chest(5 class)/Test'

In [4]:
# Hyperparameters
num_epochs = 30
learning_rate = 1e-4
batch_size = 32
input_size = 224
patience = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Data Augmentation and Transformations
transform_train = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(input_size),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val_test = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [6]:
# Datasets
train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(val_dir, transform=transform_val_test)
test_dataset = datasets.ImageFolder(test_dir, transform=transform_val_test)

In [7]:
# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
# Initialize ResNet101 model
model = models.resnet101(pretrained=True)
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, len(train_dataset.classes))
model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:01<00:00, 157MB/s]


In [9]:
# Print class names
class_names = train_dataset.classes
print(f"Classes: {class_names}")

Classes: ['Covid-19', 'Emphysema', 'Healthy', 'Pneumonia', 'Tuberculosis', 'random']


In [10]:
# Loss Function with Class Weights
class_counts = [0] * len(train_dataset.classes)
for _, label in train_dataset:
    class_counts[label] += 1
class_weights = torch.tensor(1.0 / np.array(class_counts), dtype=torch.float).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

In [11]:
# Optimizer and Scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.1, verbose=True)



In [12]:
# Training and Validation Loop
checkpoint_path = "best_resnet101_model.pth"
best_val_accuracy = 0.0
epochs_no_improve = 0

In [13]:
for epoch in range(num_epochs):
    # Training
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    # Validation
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Metrics Calculation
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    train_loss = running_loss / total_train
    val_loss = val_loss / total_val
    train_accuracy = 100 * correct_train / total_train
    val_accuracy = 100 * correct_val / total_val

    scheduler.step(val_accuracy)

    # Checkpoint Saving and Early Stopping
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), checkpoint_path)
        epochs_no_improve = 0
        print(f"[INFO] Validation accuracy improved to {val_accuracy:.2f}%, saving model...")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print("[INFO] Early stopping triggered.")
        break

    # Epoch Summary
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
          f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%, "
          f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

[INFO] Validation accuracy improved to 85.42%, saving model...
Epoch [1/30] Train Loss: 0.4114, Train Accuracy: 85.55%, Val Loss: 0.3875, Val Accuracy: 85.42%, Precision: 0.8717, Recall: 0.8542, F1-Score: 0.8530
[INFO] Validation accuracy improved to 90.08%, saving model...
Epoch [2/30] Train Loss: 0.2748, Train Accuracy: 90.27%, Val Loss: 0.2860, Val Accuracy: 90.08%, Precision: 0.9102, Recall: 0.9008, F1-Score: 0.9025
Epoch [3/30] Train Loss: 0.2394, Train Accuracy: 91.41%, Val Loss: 0.3527, Val Accuracy: 87.67%, Precision: 0.8897, Recall: 0.8767, F1-Score: 0.8772
Epoch [4/30] Train Loss: 0.2209, Train Accuracy: 92.11%, Val Loss: 0.2776, Val Accuracy: 89.58%, Precision: 0.9090, Recall: 0.8958, F1-Score: 0.8965
[INFO] Validation accuracy improved to 91.83%, saving model...
Epoch [5/30] Train Loss: 0.1964, Train Accuracy: 93.11%, Val Loss: 0.2259, Val Accuracy: 91.83%, Precision: 0.9243, Recall: 0.9183, F1-Score: 0.9194
Epoch [6/30] Train Loss: 0.1962, Train Accuracy: 93.02%, Val Loss:

In [14]:
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
model.eval()
print("Best model loaded from checkpoint")


  model.load_state_dict(torch.load(checkpoint_path))


Best model loaded from checkpoint


In [15]:
# Save the PyTorch model and state dictionary
torch.save(model, '/content/resnet101_full_model.pth')
torch.save(model.state_dict(), '/content/resnet101_state_dict.pth')


In [16]:
# Convert PyTorch Model to ONNX format
onnx_path = '/content/resnet101_model.onnx'
dummy_input = torch.randn(1, 3, input_size, input_size).to(device)
torch.onnx.export(model, dummy_input, onnx_path, verbose=True, input_names=["input"], output_names=["output"])

In [17]:
# Verify the ONNX model
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid.")


ONNX model is valid.


In [18]:
# Use ONNX Runtime for inference
ort_session = ort.InferenceSession(onnx_path)

def predict_onnx(ort_session, image):
    input_tensor = image.astype(np.float32)
    outputs = ort_session.run(None, {"input": input_tensor})
    return outputs[0]

# Example inference
for images, labels in test_loader:
    images_np = images.numpy()
    predictions = []
    for img in images_np:
        img_np = np.expand_dims(img, axis=0)
        preds = predict_onnx(ort_session, img_np)
        predictions.append(np.argmax(preds))
    print(f"Predictions: {predictions}")
    break

Predictions: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [19]:
# Download models for deployment
files.download('/content/resnet101_full_model.pth')
files.download('/content/resnet101_state_dict.pth')
files.download(onnx_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>