Switch to unified view

a b/Train & Test Code/Train.ipynb
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": [],
7
      "gpuType": "T4"
8
    },
9
    "kernelspec": {
10
      "name": "python3",
11
      "display_name": "Python 3"
12
    },
13
    "language_info": {
14
      "name": "python"
15
    },
16
    "accelerator": "GPU"
17
  },
18
  "cells": [
19
    {
20
      "cell_type": "code",
21
      "execution_count": 1,
22
      "metadata": {
23
        "colab": {
24
          "base_uri": "https://localhost:8080/"
25
        },
26
        "collapsed": true,
27
        "id": "5MvIute5t3tC",
28
        "outputId": "e4ed079c-a1bd-454e-d72e-641c8eaa5cb6"
29
      },
30
      "outputs": [
31
        {
32
          "output_type": "stream",
33
          "name": "stdout",
34
          "text": [
35
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.0/16.0 MB\u001b[0m \u001b[31m102.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
36
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m104.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
37
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m455.8/455.8 kB\u001b[0m \u001b[31m35.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
38
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m60.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
39
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
40
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
41
            "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
42
            "grpcio-status 1.62.3 requires protobuf>=4.21.6, but you have protobuf 3.20.3 which is incompatible.\u001b[0m\u001b[31m\n",
43
            "\u001b[0m"
44
          ]
45
        }
46
      ],
47
      "source": [
48
        "# Install required libraries\n",
49
        "!pip install onnx onnxruntime tf2onnx --quiet"
50
      ]
51
    },
52
    {
53
      "cell_type": "code",
54
      "source": [
55
        "import torch\n",
56
        "import numpy as np\n",
57
        "import onnx\n",
58
        "import onnxruntime as ort\n",
59
        "from torchvision import datasets, transforms, models\n",
60
        "from torch.utils.data import DataLoader\n",
61
        "from sklearn.metrics import precision_score, recall_score, f1_score\n",
62
        "from google.colab import files"
63
      ],
64
      "metadata": {
65
        "id": "0qNaAqdBus3w"
66
      },
67
      "execution_count": 2,
68
      "outputs": []
69
    },
70
    {
71
      "cell_type": "code",
72
      "source": [
73
        "# Dataset Directories\n",
74
        "train_dir = '/content/drive/MyDrive/update dataset/New Chest(5 class)/Train'\n",
75
        "val_dir = '/content/drive/MyDrive/update dataset/New Chest(5 class)/Val'\n",
76
        "test_dir = '/content/drive/MyDrive/update dataset/New Chest(5 class)/Test'"
77
      ],
78
      "metadata": {
79
        "id": "yAcg0j20us9W"
80
      },
81
      "execution_count": 3,
82
      "outputs": []
83
    },
84
    {
85
      "cell_type": "code",
86
      "source": [
87
        "# Hyperparameters\n",
88
        "num_epochs = 30\n",
89
        "learning_rate = 1e-4\n",
90
        "batch_size = 32\n",
91
        "input_size = 224\n",
92
        "patience = 5\n",
93
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
94
      ],
95
      "metadata": {
96
        "id": "g1QefxvTus_u"
97
      },
98
      "execution_count": 4,
99
      "outputs": []
100
    },
101
    {
102
      "cell_type": "code",
103
      "source": [
104
        "# Data Augmentation and Transformations\n",
105
        "transform_train = transforms.Compose([\n",
106
        "    transforms.Resize((input_size, input_size)),\n",
107
        "    transforms.RandomHorizontalFlip(),\n",
108
        "    transforms.RandomRotation(10),\n",
109
        "    transforms.RandomResizedCrop(input_size),\n",
110
        "    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),\n",
111
        "    transforms.ToTensor(),\n",
112
        "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
113
        "])\n",
114
        "\n",
115
        "transform_val_test = transforms.Compose([\n",
116
        "    transforms.Resize((input_size, input_size)),\n",
117
        "    transforms.ToTensor(),\n",
118
        "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
119
        "])"
120
      ],
121
      "metadata": {
122
        "id": "80Td29j8utCH"
123
      },
124
      "execution_count": 5,
125
      "outputs": []
126
    },
127
    {
128
      "cell_type": "code",
129
      "source": [
130
        "# Datasets\n",
131
        "train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)\n",
132
        "val_dataset = datasets.ImageFolder(val_dir, transform=transform_val_test)\n",
133
        "test_dataset = datasets.ImageFolder(test_dir, transform=transform_val_test)"
134
      ],
135
      "metadata": {
136
        "id": "lBT2g4xxutEe"
137
      },
138
      "execution_count": 6,
139
      "outputs": []
140
    },
141
    {
142
      "cell_type": "code",
143
      "source": [
144
        "# Dataloaders\n",
145
        "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
146
        "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
147
        "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
148
      ],
149
      "metadata": {
150
        "id": "EdGkD6sTutGu"
151
      },
152
      "execution_count": 7,
153
      "outputs": []
154
    },
155
    {
156
      "cell_type": "code",
157
      "source": [
158
        "# Initialize ResNet101 model\n",
159
        "model = models.resnet101(pretrained=True)\n",
160
        "num_features = model.fc.in_features\n",
161
        "model.fc = torch.nn.Linear(num_features, len(train_dataset.classes))\n",
162
        "model = model.to(device)\n"
163
      ],
164
      "metadata": {
165
        "colab": {
166
          "base_uri": "https://localhost:8080/"
167
        },
168
        "id": "d6_-08SWutLe",
169
        "outputId": "b69a12fc-8e1e-4821-d350-53d78923a79b"
170
      },
171
      "execution_count": 8,
172
      "outputs": [
173
        {
174
          "output_type": "stream",
175
          "name": "stderr",
176
          "text": [
177
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
178
            "  warnings.warn(\n",
179
            "/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet101_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet101_Weights.DEFAULT` to get the most up-to-date weights.\n",
180
            "  warnings.warn(msg)\n",
181
            "Downloading: \"https://download.pytorch.org/models/resnet101-63fe2227.pth\" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth\n",
182
            "100%|██████████| 171M/171M [00:01<00:00, 157MB/s]\n"
183
          ]
184
        }
185
      ]
186
    },
187
    {
188
      "cell_type": "code",
189
      "source": [
190
        "# Print class names\n",
191
        "class_names = train_dataset.classes\n",
192
        "print(f\"Classes: {class_names}\")"
193
      ],
194
      "metadata": {
195
        "colab": {
196
          "base_uri": "https://localhost:8080/"
197
        },
198
        "id": "Zkby_asSutOG",
199
        "outputId": "27f39fec-0d0c-4289-b1e8-1dc282f2b082"
200
      },
201
      "execution_count": 9,
202
      "outputs": [
203
        {
204
          "output_type": "stream",
205
          "name": "stdout",
206
          "text": [
207
            "Classes: ['Covid-19', 'Emphysema', 'Healthy', 'Pneumonia', 'Tuberculosis', 'random']\n"
208
          ]
209
        }
210
      ]
211
    },
212
    {
213
      "cell_type": "code",
214
      "source": [
215
        "# Loss Function with Class Weights\n",
216
        "class_counts = [0] * len(train_dataset.classes)\n",
217
        "for _, label in train_dataset:\n",
218
        "    class_counts[label] += 1\n",
219
        "class_weights = torch.tensor(1.0 / np.array(class_counts), dtype=torch.float).to(device)\n",
220
        "criterion = torch.nn.CrossEntropyLoss(weight=class_weights)"
221
      ],
222
      "metadata": {
223
        "id": "uA5JfDwsutQn"
224
      },
225
      "execution_count": 10,
226
      "outputs": []
227
    },
228
    {
229
      "cell_type": "code",
230
      "source": [
231
        "# Optimizer and Scheduler\n",
232
        "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
233
        "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.1, verbose=True)"
234
      ],
235
      "metadata": {
236
        "colab": {
237
          "base_uri": "https://localhost:8080/"
238
        },
239
        "id": "sg69GD-qutVu",
240
        "outputId": "0f3708fb-5838-46b7-c9cd-76807a434eb0"
241
      },
242
      "execution_count": 11,
243
      "outputs": [
244
        {
245
          "output_type": "stream",
246
          "name": "stderr",
247
          "text": [
248
            "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
249
            "  warnings.warn(\n"
250
          ]
251
        }
252
      ]
253
    },
254
    {
255
      "cell_type": "code",
256
      "source": [
257
        "# Training and Validation Loop\n",
258
        "checkpoint_path = \"best_resnet101_model.pth\"\n",
259
        "best_val_accuracy = 0.0\n",
260
        "epochs_no_improve = 0"
261
      ],
262
      "metadata": {
263
        "id": "3LtoehYPutYO"
264
      },
265
      "execution_count": 12,
266
      "outputs": []
267
    },
268
    {
269
      "cell_type": "code",
270
      "source": [
271
        "for epoch in range(num_epochs):\n",
272
        "    # Training\n",
273
        "    model.train()\n",
274
        "    running_loss = 0.0\n",
275
        "    correct_train = 0\n",
276
        "    total_train = 0\n",
277
        "\n",
278
        "    for images, labels in train_loader:\n",
279
        "        images, labels = images.to(device), labels.to(device)\n",
280
        "\n",
281
        "        optimizer.zero_grad()\n",
282
        "        outputs = model(images)\n",
283
        "        loss = criterion(outputs, labels)\n",
284
        "        loss.backward()\n",
285
        "        optimizer.step()\n",
286
        "\n",
287
        "        running_loss += loss.item() * images.size(0)\n",
288
        "        _, preds = torch.max(outputs, 1)\n",
289
        "        correct_train += (preds == labels).sum().item()\n",
290
        "        total_train += labels.size(0)\n",
291
        "\n",
292
        "    # Validation\n",
293
        "    model.eval()\n",
294
        "    val_loss = 0.0\n",
295
        "    correct_val = 0\n",
296
        "    total_val = 0\n",
297
        "    all_preds = []\n",
298
        "    all_labels = []\n",
299
        "\n",
300
        "    with torch.no_grad():\n",
301
        "        for images, labels in val_loader:\n",
302
        "            images, labels = images.to(device), labels.to(device)\n",
303
        "            outputs = model(images)\n",
304
        "            loss = criterion(outputs, labels)\n",
305
        "            val_loss += loss.item() * images.size(0)\n",
306
        "            _, preds = torch.max(outputs, 1)\n",
307
        "            correct_val += (preds == labels).sum().item()\n",
308
        "            total_val += labels.size(0)\n",
309
        "\n",
310
        "            all_preds.extend(preds.cpu().numpy())\n",
311
        "            all_labels.extend(labels.cpu().numpy())\n",
312
        "\n",
313
        "    # Metrics Calculation\n",
314
        "    precision = precision_score(all_labels, all_preds, average='weighted')\n",
315
        "    recall = recall_score(all_labels, all_preds, average='weighted')\n",
316
        "    f1 = f1_score(all_labels, all_preds, average='weighted')\n",
317
        "\n",
318
        "    train_loss = running_loss / total_train\n",
319
        "    val_loss = val_loss / total_val\n",
320
        "    train_accuracy = 100 * correct_train / total_train\n",
321
        "    val_accuracy = 100 * correct_val / total_val\n",
322
        "\n",
323
        "    scheduler.step(val_accuracy)\n",
324
        "\n",
325
        "    # Checkpoint Saving and Early Stopping\n",
326
        "    if val_accuracy > best_val_accuracy:\n",
327
        "        best_val_accuracy = val_accuracy\n",
328
        "        torch.save(model.state_dict(), checkpoint_path)\n",
329
        "        epochs_no_improve = 0\n",
330
        "        print(f\"[INFO] Validation accuracy improved to {val_accuracy:.2f}%, saving model...\")\n",
331
        "    else:\n",
332
        "        epochs_no_improve += 1\n",
333
        "\n",
334
        "    if epochs_no_improve >= patience:\n",
335
        "        print(\"[INFO] Early stopping triggered.\")\n",
336
        "        break\n",
337
        "\n",
338
        "    # Epoch Summary\n",
339
        "    print(f\"Epoch [{epoch+1}/{num_epochs}] \"\n",
340
        "          f\"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, \"\n",
341
        "          f\"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%, \"\n",
342
        "          f\"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}\")"
343
      ],
344
      "metadata": {
345
        "colab": {
346
          "base_uri": "https://localhost:8080/"
347
        },
348
        "id": "ShZd_TVywrAQ",
349
        "outputId": "42729e46-08f3-4799-84fe-2962c95fca4e"
350
      },
351
      "execution_count": 13,
352
      "outputs": [
353
        {
354
          "output_type": "stream",
355
          "name": "stdout",
356
          "text": [
357
            "[INFO] Validation accuracy improved to 85.42%, saving model...\n",
358
            "Epoch [1/30] Train Loss: 0.4114, Train Accuracy: 85.55%, Val Loss: 0.3875, Val Accuracy: 85.42%, Precision: 0.8717, Recall: 0.8542, F1-Score: 0.8530\n",
359
            "[INFO] Validation accuracy improved to 90.08%, saving model...\n",
360
            "Epoch [2/30] Train Loss: 0.2748, Train Accuracy: 90.27%, Val Loss: 0.2860, Val Accuracy: 90.08%, Precision: 0.9102, Recall: 0.9008, F1-Score: 0.9025\n",
361
            "Epoch [3/30] Train Loss: 0.2394, Train Accuracy: 91.41%, Val Loss: 0.3527, Val Accuracy: 87.67%, Precision: 0.8897, Recall: 0.8767, F1-Score: 0.8772\n",
362
            "Epoch [4/30] Train Loss: 0.2209, Train Accuracy: 92.11%, Val Loss: 0.2776, Val Accuracy: 89.58%, Precision: 0.9090, Recall: 0.8958, F1-Score: 0.8965\n",
363
            "[INFO] Validation accuracy improved to 91.83%, saving model...\n",
364
            "Epoch [5/30] Train Loss: 0.1964, Train Accuracy: 93.11%, Val Loss: 0.2259, Val Accuracy: 91.83%, Precision: 0.9243, Recall: 0.9183, F1-Score: 0.9194\n",
365
            "Epoch [6/30] Train Loss: 0.1962, Train Accuracy: 93.02%, Val Loss: 0.2947, Val Accuracy: 90.75%, Precision: 0.9169, Recall: 0.9075, F1-Score: 0.9082\n",
366
            "Epoch [7/30] Train Loss: 0.1872, Train Accuracy: 93.32%, Val Loss: 0.3128, Val Accuracy: 90.00%, Precision: 0.9115, Recall: 0.9000, F1-Score: 0.8995\n",
367
            "Epoch [8/30] Train Loss: 0.1740, Train Accuracy: 93.69%, Val Loss: 0.3236, Val Accuracy: 90.25%, Precision: 0.9173, Recall: 0.9025, F1-Score: 0.9033\n",
368
            "Epoch [9/30] Train Loss: 0.1553, Train Accuracy: 94.46%, Val Loss: 0.2888, Val Accuracy: 90.17%, Precision: 0.9116, Recall: 0.9017, F1-Score: 0.9016\n",
369
            "[INFO] Validation accuracy improved to 92.08%, saving model...\n",
370
            "Epoch [10/30] Train Loss: 0.1164, Train Accuracy: 95.91%, Val Loss: 0.2322, Val Accuracy: 92.08%, Precision: 0.9264, Recall: 0.9208, F1-Score: 0.9211\n",
371
            "[INFO] Validation accuracy improved to 92.58%, saving model...\n",
372
            "Epoch [11/30] Train Loss: 0.0950, Train Accuracy: 96.73%, Val Loss: 0.2173, Val Accuracy: 92.58%, Precision: 0.9320, Recall: 0.9258, F1-Score: 0.9263\n",
373
            "[INFO] Validation accuracy improved to 92.75%, saving model...\n",
374
            "Epoch [12/30] Train Loss: 0.0924, Train Accuracy: 96.67%, Val Loss: 0.2152, Val Accuracy: 92.75%, Precision: 0.9329, Recall: 0.9275, F1-Score: 0.9278\n",
375
            "[INFO] Validation accuracy improved to 93.17%, saving model...\n",
376
            "Epoch [13/30] Train Loss: 0.0840, Train Accuracy: 97.15%, Val Loss: 0.1846, Val Accuracy: 93.17%, Precision: 0.9359, Recall: 0.9317, F1-Score: 0.9322\n",
377
            "[INFO] Validation accuracy improved to 93.50%, saving model...\n",
378
            "Epoch [14/30] Train Loss: 0.0814, Train Accuracy: 97.31%, Val Loss: 0.1971, Val Accuracy: 93.50%, Precision: 0.9388, Recall: 0.9350, F1-Score: 0.9353\n",
379
            "[INFO] Validation accuracy improved to 94.08%, saving model...\n",
380
            "Epoch [15/30] Train Loss: 0.0787, Train Accuracy: 97.36%, Val Loss: 0.1660, Val Accuracy: 94.08%, Precision: 0.9424, Recall: 0.9408, F1-Score: 0.9411\n",
381
            "Epoch [16/30] Train Loss: 0.0752, Train Accuracy: 97.43%, Val Loss: 0.2176, Val Accuracy: 92.67%, Precision: 0.9302, Recall: 0.9267, F1-Score: 0.9272\n",
382
            "Epoch [17/30] Train Loss: 0.0746, Train Accuracy: 97.28%, Val Loss: 0.2189, Val Accuracy: 92.08%, Precision: 0.9267, Recall: 0.9208, F1-Score: 0.9217\n",
383
            "Epoch [18/30] Train Loss: 0.0741, Train Accuracy: 97.50%, Val Loss: 0.2131, Val Accuracy: 92.58%, Precision: 0.9300, Recall: 0.9258, F1-Score: 0.9264\n",
384
            "Epoch [19/30] Train Loss: 0.0746, Train Accuracy: 97.47%, Val Loss: 0.2179, Val Accuracy: 93.33%, Precision: 0.9371, Recall: 0.9333, F1-Score: 0.9336\n",
385
            "[INFO] Early stopping triggered.\n"
386
          ]
387
        }
388
      ]
389
    },
390
    {
391
      "cell_type": "code",
392
      "source": [
393
        "# Load the best model\n",
394
        "model.load_state_dict(torch.load(checkpoint_path))\n",
395
        "model.eval()\n",
396
        "print(\"Best model loaded from checkpoint\")\n"
397
      ],
398
      "metadata": {
399
        "colab": {
400
          "base_uri": "https://localhost:8080/"
401
        },
402
        "id": "iMTEZ4r8wrC6",
403
        "outputId": "c6cf6863-d91c-4a0b-d0ac-bdc7adee899d"
404
      },
405
      "execution_count": 14,
406
      "outputs": [
407
        {
408
          "output_type": "stream",
409
          "name": "stderr",
410
          "text": [
411
            "<ipython-input-14-978f4fcb01ca>:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
412
            "  model.load_state_dict(torch.load(checkpoint_path))\n"
413
          ]
414
        },
415
        {
416
          "output_type": "stream",
417
          "name": "stdout",
418
          "text": [
419
            "Best model loaded from checkpoint\n"
420
          ]
421
        }
422
      ]
423
    },
424
    {
425
      "cell_type": "code",
426
      "source": [
427
        "# Save the PyTorch model and state dictionary\n",
428
        "torch.save(model, '/content/resnet101_full_model.pth')\n",
429
        "torch.save(model.state_dict(), '/content/resnet101_state_dict.pth')\n"
430
      ],
431
      "metadata": {
432
        "id": "hoNq--sXwrFb"
433
      },
434
      "execution_count": 15,
435
      "outputs": []
436
    },
437
    {
438
      "cell_type": "code",
439
      "source": [
440
        "# Convert PyTorch Model to ONNX format\n",
441
        "onnx_path = '/content/resnet101_model.onnx'\n",
442
        "dummy_input = torch.randn(1, 3, input_size, input_size).to(device)\n",
443
        "torch.onnx.export(model, dummy_input, onnx_path, verbose=True, input_names=[\"input\"], output_names=[\"output\"])"
444
      ],
445
      "metadata": {
446
        "id": "_V1uY9wvwrHp"
447
      },
448
      "execution_count": 16,
449
      "outputs": []
450
    },
451
    {
452
      "cell_type": "code",
453
      "source": [
454
        "# Verify the ONNX model\n",
455
        "onnx_model = onnx.load(onnx_path)\n",
456
        "onnx.checker.check_model(onnx_model)\n",
457
        "print(\"ONNX model is valid.\")\n"
458
      ],
459
      "metadata": {
460
        "colab": {
461
          "base_uri": "https://localhost:8080/"
462
        },
463
        "id": "IlbAheXlwrKB",
464
        "outputId": "16c0aa95-2e75-432b-d9ce-5238fce92744"
465
      },
466
      "execution_count": 17,
467
      "outputs": [
468
        {
469
          "output_type": "stream",
470
          "name": "stdout",
471
          "text": [
472
            "ONNX model is valid.\n"
473
          ]
474
        }
475
      ]
476
    },
477
    {
478
      "cell_type": "code",
479
      "source": [
480
        "# Use ONNX Runtime for inference\n",
481
        "ort_session = ort.InferenceSession(onnx_path)\n",
482
        "\n",
483
        "def predict_onnx(ort_session, image):\n",
484
        "    input_tensor = image.astype(np.float32)\n",
485
        "    outputs = ort_session.run(None, {\"input\": input_tensor})\n",
486
        "    return outputs[0]\n",
487
        "\n",
488
        "# Example inference\n",
489
        "for images, labels in test_loader:\n",
490
        "    images_np = images.numpy()\n",
491
        "    predictions = []\n",
492
        "    for img in images_np:\n",
493
        "        img_np = np.expand_dims(img, axis=0)\n",
494
        "        preds = predict_onnx(ort_session, img_np)\n",
495
        "        predictions.append(np.argmax(preds))\n",
496
        "    print(f\"Predictions: {predictions}\")\n",
497
        "    break"
498
      ],
499
      "metadata": {
500
        "colab": {
501
          "base_uri": "https://localhost:8080/"
502
        },
503
        "id": "MdRS709s3Eb4",
504
        "outputId": "5aa03a76-a2bf-4d90-fa15-2435ff93544f"
505
      },
506
      "execution_count": 18,
507
      "outputs": [
508
        {
509
          "output_type": "stream",
510
          "name": "stdout",
511
          "text": [
512
            "Predictions: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
513
          ]
514
        }
515
      ]
516
    },
517
    {
518
      "cell_type": "code",
519
      "source": [
520
        "# Download models for deployment\n",
521
        "files.download('/content/resnet101_full_model.pth')\n",
522
        "files.download('/content/resnet101_state_dict.pth')\n",
523
        "files.download(onnx_path)"
524
      ],
525
      "metadata": {
526
        "colab": {
527
          "base_uri": "https://localhost:8080/",
528
          "height": 17
529
        },
530
        "id": "Tu1q_dy8wrVw",
531
        "outputId": "398cb159-62ee-4800-8878-b0c2a7e9bd37"
532
      },
533
      "execution_count": 19,
534
      "outputs": [
535
        {
536
          "data": {
537
            "application/javascript": [
538
              "\n",
539
              "    async function download(id, filename, size) {\n",
540
              "      if (!google.colab.kernel.accessAllowed) {\n",
541
              "        return;\n",
542
              "      }\n",
543
              "      const div = document.createElement('div');\n",
544
              "      const label = document.createElement('label');\n",
545
              "      label.textContent = `Downloading \"${filename}\": `;\n",
546
              "      div.appendChild(label);\n",
547
              "      const progress = document.createElement('progress');\n",
548
              "      progress.max = size;\n",
549
              "      div.appendChild(progress);\n",
550
              "      document.body.appendChild(div);\n",
551
              "\n",
552
              "      const buffers = [];\n",
553
              "      let downloaded = 0;\n",
554
              "\n",
555
              "      const channel = await google.colab.kernel.comms.open(id);\n",
556
              "      // Send a message to notify the kernel that we're ready.\n",
557
              "      channel.send({})\n",
558
              "\n",
559
              "      for await (const message of channel.messages) {\n",
560
              "        // Send a message to notify the kernel that we're ready.\n",
561
              "        channel.send({})\n",
562
              "        if (message.buffers) {\n",
563
              "          for (const buffer of message.buffers) {\n",
564
              "            buffers.push(buffer);\n",
565
              "            downloaded += buffer.byteLength;\n",
566
              "            progress.value = downloaded;\n",
567
              "          }\n",
568
              "        }\n",
569
              "      }\n",
570
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
571
              "      const a = document.createElement('a');\n",
572
              "      a.href = window.URL.createObjectURL(blob);\n",
573
              "      a.download = filename;\n",
574
              "      div.appendChild(a);\n",
575
              "      a.click();\n",
576
              "      div.remove();\n",
577
              "    }\n",
578
              "  "
579
            ],
580
            "text/plain": [
581
              "<IPython.core.display.Javascript object>"
582
            ]
583
          },
584
          "metadata": {},
585
          "output_type": "display_data"
586
        },
587
        {
588
          "data": {
589
            "application/javascript": [
590
              "download(\"download_2676e401-61a4-436d-ac9c-a1fc4cb098c4\", \"resnet101_full_model.pth\", 170735052)"
591
            ],
592
            "text/plain": [
593
              "<IPython.core.display.Javascript object>"
594
            ]
595
          },
596
          "metadata": {},
597
          "output_type": "display_data"
598
        },
599
        {
600
          "data": {
601
            "application/javascript": [
602
              "\n",
603
              "    async function download(id, filename, size) {\n",
604
              "      if (!google.colab.kernel.accessAllowed) {\n",
605
              "        return;\n",
606
              "      }\n",
607
              "      const div = document.createElement('div');\n",
608
              "      const label = document.createElement('label');\n",
609
              "      label.textContent = `Downloading \"${filename}\": `;\n",
610
              "      div.appendChild(label);\n",
611
              "      const progress = document.createElement('progress');\n",
612
              "      progress.max = size;\n",
613
              "      div.appendChild(progress);\n",
614
              "      document.body.appendChild(div);\n",
615
              "\n",
616
              "      const buffers = [];\n",
617
              "      let downloaded = 0;\n",
618
              "\n",
619
              "      const channel = await google.colab.kernel.comms.open(id);\n",
620
              "      // Send a message to notify the kernel that we're ready.\n",
621
              "      channel.send({})\n",
622
              "\n",
623
              "      for await (const message of channel.messages) {\n",
624
              "        // Send a message to notify the kernel that we're ready.\n",
625
              "        channel.send({})\n",
626
              "        if (message.buffers) {\n",
627
              "          for (const buffer of message.buffers) {\n",
628
              "            buffers.push(buffer);\n",
629
              "            downloaded += buffer.byteLength;\n",
630
              "            progress.value = downloaded;\n",
631
              "          }\n",
632
              "        }\n",
633
              "      }\n",
634
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
635
              "      const a = document.createElement('a');\n",
636
              "      a.href = window.URL.createObjectURL(blob);\n",
637
              "      a.download = filename;\n",
638
              "      div.appendChild(a);\n",
639
              "      a.click();\n",
640
              "      div.remove();\n",
641
              "    }\n",
642
              "  "
643
            ],
644
            "text/plain": [
645
              "<IPython.core.display.Javascript object>"
646
            ]
647
          },
648
          "metadata": {},
649
          "output_type": "display_data"
650
        },
651
        {
652
          "data": {
653
            "application/javascript": [
654
              "download(\"download_4a55caaf-b2c4-4798-a186-8a4c0958245b\", \"resnet101_state_dict.pth\", 170685004)"
655
            ],
656
            "text/plain": [
657
              "<IPython.core.display.Javascript object>"
658
            ]
659
          },
660
          "metadata": {},
661
          "output_type": "display_data"
662
        },
663
        {
664
          "data": {
665
            "application/javascript": [
666
              "\n",
667
              "    async function download(id, filename, size) {\n",
668
              "      if (!google.colab.kernel.accessAllowed) {\n",
669
              "        return;\n",
670
              "      }\n",
671
              "      const div = document.createElement('div');\n",
672
              "      const label = document.createElement('label');\n",
673
              "      label.textContent = `Downloading \"${filename}\": `;\n",
674
              "      div.appendChild(label);\n",
675
              "      const progress = document.createElement('progress');\n",
676
              "      progress.max = size;\n",
677
              "      div.appendChild(progress);\n",
678
              "      document.body.appendChild(div);\n",
679
              "\n",
680
              "      const buffers = [];\n",
681
              "      let downloaded = 0;\n",
682
              "\n",
683
              "      const channel = await google.colab.kernel.comms.open(id);\n",
684
              "      // Send a message to notify the kernel that we're ready.\n",
685
              "      channel.send({})\n",
686
              "\n",
687
              "      for await (const message of channel.messages) {\n",
688
              "        // Send a message to notify the kernel that we're ready.\n",
689
              "        channel.send({})\n",
690
              "        if (message.buffers) {\n",
691
              "          for (const buffer of message.buffers) {\n",
692
              "            buffers.push(buffer);\n",
693
              "            downloaded += buffer.byteLength;\n",
694
              "            progress.value = downloaded;\n",
695
              "          }\n",
696
              "        }\n",
697
              "      }\n",
698
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
699
              "      const a = document.createElement('a');\n",
700
              "      a.href = window.URL.createObjectURL(blob);\n",
701
              "      a.download = filename;\n",
702
              "      div.appendChild(a);\n",
703
              "      a.click();\n",
704
              "      div.remove();\n",
705
              "    }\n",
706
              "  "
707
            ],
708
            "text/plain": [
709
              "<IPython.core.display.Javascript object>"
710
            ]
711
          },
712
          "metadata": {},
713
          "output_type": "display_data"
714
        },
715
        {
716
          "data": {
717
            "application/javascript": [
718
              "download(\"download_84479e94-0d0c-41e5-bc32-81e18fc96116\", \"resnet101_model.onnx\", 171040381)"
719
            ],
720
            "text/plain": [
721
              "<IPython.core.display.Javascript object>"
722
            ]
723
          },
724
          "metadata": {},
725
          "output_type": "display_data"
726
        }
727
      ]
728
    }
729
  ]
730
}