152 lines (152 with data), 9.4 kB
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Untitled2.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "uGLc0edWLE5B",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "6dd94c94-b3aa-463f-bf47-bc4b92d2f528"
},
"source": [
"# Pin pytorch-lightning to the release this notebook was run with (1.1.2):\r\n",
"# the `pytorch_lightning.metrics` module imported below was moved out into the\r\n",
"# separate torchmetrics package in later releases, so an unpinned install can\r\n",
"# break the import on a fresh runtime.\r\n",
"# `%pip` (instead of `!pip`) installs into the environment of the running kernel.\r\n",
"%pip install pytorch-lightning==1.1.2\r\n",
"%pip install torch"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: pytorch-lightning in /usr/local/lib/python3.6/dist-packages (1.1.2)\n",
"Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (2.4.0)\n",
"Requirement already satisfied: torch>=1.3 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (1.7.0+cu101)\n",
"Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (5.3.1)\n",
"Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (0.18.2)\n",
"Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (1.19.4)\n",
"Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (4.41.1)\n",
"Requirement already satisfied: fsspec>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (0.8.5)\n",
"Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.10.0)\n",
"Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.32.0)\n",
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.1)\n",
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.2)\n",
"Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.15.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.3)\n",
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.7.0)\n",
"Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (51.0.0)\n",
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.36.2)\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (2.23.0)\n",
"Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.12.4)\n",
"Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.17.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.3->pytorch-lightning) (3.7.4.3)\n",
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.3->pytorch-lightning) (0.8)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.0)\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.3.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (2020.12.5)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.0)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.8)\n",
"Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.6)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.1.0)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.4.0)\n",
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)\n",
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch) (0.8)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch) (3.7.4.3)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.19.4)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch) (0.18.2)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bI1Wf7bTSzhd"
},
"source": [
"import torch\r\n",
"import pytorch_lightning as pl\r\n",
"import numpy as np\r\n",
"\r\n",
"\r\n",
"from pytorch_lightning.metrics import Metric\r\n",
"\r\n",
"class IntersectionOverUnion(Metric):\r\n",
"    \"\"\"Intersection-over-Union of two boolean masks, accumulated over batches.\r\n",
"\r\n",
"    Every `update` call adds the batch's intersection and union element counts\r\n",
"    to two running totals; `compute` returns total intersection / total union.\r\n",
"    Both totals are registered as states with `dist_reduce_fx=\"sum\"`.\r\n",
"\r\n",
"    Args:\r\n",
"        dist_sync_on_step: forwarded unchanged to `Metric.__init__`.\r\n",
"    \"\"\"\r\n",
"\r\n",
"    def __init__(self, dist_sync_on_step=False):\r\n",
"        super().__init__(dist_sync_on_step=dist_sync_on_step)\r\n",
"\r\n",
"        # Scalar integer counters: they must stay scalars so that the \"sum\"\r\n",
"        # reduction and repeated `update` calls remain meaningful.\r\n",
"        self.add_state(\"intersection\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\r\n",
"        self.add_state(\"union\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\r\n",
"\r\n",
"    def update(self, preds, target):\r\n",
"        \"\"\"Add one batch of masks to the running counts.\r\n",
"\r\n",
"        Args:\r\n",
"            preds: predicted mask; anything `torch.as_tensor` accepts (tensor,\r\n",
"                numpy array, nested list). Non-zero entries count as True.\r\n",
"            target: ground-truth mask with the same shape as `preds`.\r\n",
"\r\n",
"        Raises:\r\n",
"            AssertionError: if `preds` and `target` differ in shape.\r\n",
"        \"\"\"\r\n",
"        preds = torch.as_tensor(preds).bool()\r\n",
"        target = torch.as_tensor(target).bool()\r\n",
"        assert preds.shape == target.shape, \"preds and target must have the same shape\"\r\n",
"\r\n",
"        # Accumulate counts instead of overwriting the states with whole masks,\r\n",
"        # so no reference to the batch itself is kept on the metric.\r\n",
"        device = self.intersection.device\r\n",
"        self.intersection += (preds & target).sum().to(device)\r\n",
"        self.union += (preds | target).sum().to(device)\r\n",
"\r\n",
"    def compute(self):\r\n",
"        \"\"\"Return accumulated intersection / union as a 0-dim float tensor.\r\n",
"\r\n",
"        The result is NaN (0 / 0) while the accumulated union is still empty.\r\n",
"        \"\"\"\r\n",
"        return self.intersection.float() / self.union.float()\r\n"
],
"execution_count": 117,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jxubOFn42cM-",
"outputId": "d70c5df4-0dd1-4777-a0a7-5a008b089ddb"
},
"source": [
"metric = IntersectionOverUnion()\r\n",
"\r\n",
"# Worked example: the masks agree on 3 True positions and together cover 4,\r\n",
"# so the expected IoU is 3 / 4 = 0.75.\r\n",
"array_1 = np.array([True, True, True, False, True])\r\n",
"array_2 = np.array([True, False, True, False, True])\r\n",
"\r\n",
"metric.update(array_1, array_2)\r\n",
"IOU_score = metric.compute()\r\n",
"\r\n",
"# Fail loudly if the metric ever stops reproducing the hand-computed value.\r\n",
"assert np.isclose(float(IOU_score), 0.75), IOU_score\r\n",
"print(float(IOU_score))"
],
"execution_count": 116,
"outputs": [
{
"output_type": "stream",
"text": [
"0.75\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GWek8JBO6bqJ"
},
"source": [
""
]
}
]
}