[44bf8c]: / Notebooks / old / metric_IOU.ipynb

Download this file

152 lines (152 with data), 9.4 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Untitled2.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "uGLc0edWLE5B",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6dd94c94-b3aa-463f-bf47-bc4b92d2f528"
      },
      "source": [
        "# Pinned to the release this notebook was run with: `pytorch_lightning.metrics`\r\n",
        "# (imported in the next cell) was moved to the separate `torchmetrics` package\r\n",
        "# in later releases, so an unpinned install breaks that import.\r\n",
        "# `%pip` installs into the running kernel's environment.\r\n",
        "%pip install -q pytorch-lightning==1.1.2\r\n",
        "%pip install -q torch"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pytorch-lightning in /usr/local/lib/python3.6/dist-packages (1.1.2)\n",
            "Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (2.4.0)\n",
            "Requirement already satisfied: torch>=1.3 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (1.7.0+cu101)\n",
            "Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (5.3.1)\n",
            "Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (0.18.2)\n",
            "Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (1.19.4)\n",
            "Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (4.41.1)\n",
            "Requirement already satisfied: fsspec>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning) (0.8.5)\n",
            "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.10.0)\n",
            "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.32.0)\n",
            "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.1)\n",
            "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.2)\n",
            "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.15.0)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.3)\n",
            "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.7.0)\n",
            "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (51.0.0)\n",
            "Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.36.2)\n",
            "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (2.23.0)\n",
            "Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.12.4)\n",
            "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.17.2)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.3->pytorch-lightning) (3.7.4.3)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.3->pytorch-lightning) (0.8)\n",
            "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.0)\n",
            "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.3.0)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (2.10)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch-lightning) (2020.12.5)\n",
            "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.0)\n",
            "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.8)\n",
            "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.6)\n",
            "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.1.0)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.4.0)\n",
            "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch) (0.8)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch) (3.7.4.3)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.19.4)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch) (0.18.2)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bI1Wf7bTSzhd"
      },
      "source": [
        "import torch\r\n",
        "import pytorch_lightning as pl\r\n",
        "import numpy as np\r\n",
        "\r\n",
        "\r\n",
        "from pytorch_lightning.metrics import Metric\r\n",
        "\r\n",
        "class IntersectionOverUnion(Metric):\r\n",
        "    \"\"\"Intersection over Union (Jaccard index) of binary masks, accumulated over batches.\r\n",
        "\r\n",
        "    Each ``update`` call adds the batch's intersection and union element\r\n",
        "    counts to two running totals; ``compute`` divides those totals. Both\r\n",
        "    totals are registered through ``add_state`` with ``dist_reduce_fx=\"sum\"``.\r\n",
        "\r\n",
        "    Args:\r\n",
        "        dist_sync_on_step: forwarded unchanged to ``Metric.__init__``.\r\n",
        "    \"\"\"\r\n",
        "\r\n",
        "    def __init__(self, dist_sync_on_step=False):\r\n",
        "        super().__init__(dist_sync_on_step=dist_sync_on_step)\r\n",
        "\r\n",
        "        # Running element counts, kept as integer tensors so the sums stay exact.\r\n",
        "        self.add_state(\"intersection\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\r\n",
        "        self.add_state(\"union\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\r\n",
        "\r\n",
        "    def update(self, preds, target):\r\n",
        "        \"\"\"Add one batch of masks to the running intersection/union counts.\r\n",
        "\r\n",
        "        Args:\r\n",
        "            preds: predicted mask (``torch.Tensor``, NumPy array or nested\r\n",
        "                sequence); non-zero entries count as positive.\r\n",
        "            target: ground-truth mask with the same shape as ``preds``.\r\n",
        "\r\n",
        "        Raises:\r\n",
        "            AssertionError: if ``preds`` and ``target`` differ in shape.\r\n",
        "        \"\"\"\r\n",
        "        # Convert with torch (not NumPy) and move to the state's device so\r\n",
        "        # NumPy arrays, CPU tensors and GPU tensors are all accepted.\r\n",
        "        device = self.intersection.device\r\n",
        "        preds = torch.as_tensor(preds, device=device)\r\n",
        "        target = torch.as_tensor(target, device=device)\r\n",
        "        assert preds.shape == target.shape, \"preds and target must have the same shape\"\r\n",
        "\r\n",
        "        # Accumulate scalar counts in place; overwriting the registered states\r\n",
        "        # would discard earlier batches and defeat the \"sum\" reduction.\r\n",
        "        self.intersection += torch.logical_and(target, preds).sum()\r\n",
        "        self.union += torch.logical_or(target, preds).sum()\r\n",
        "\r\n",
        "    def compute(self):\r\n",
        "        \"\"\"Return the IoU over all batches accumulated in the state.\r\n",
        "\r\n",
        "        Returns:\r\n",
        "            0-dim float tensor ``intersection / union``. The value is ``nan``\r\n",
        "            when the accumulated union is empty (0 / 0), including when\r\n",
        "            ``update`` has not been called.\r\n",
        "        \"\"\"\r\n",
        "        return self.intersection.float() / self.union.float()\r\n"
      ],
      "execution_count": 117,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jxubOFn42cM-",
        "outputId": "d70c5df4-0dd1-4777-a0a7-5a008b089ddb"
      },
      "source": [
        "# Toy sanity check: the masks agree on 3 positive positions and their\r\n",
        "# union covers 4 positions.\r\n",
        "predicted_mask = np.array([True, True, True, False, True])\r\n",
        "target_mask = np.array([True, False, True, False, True])\r\n",
        "\r\n",
        "metric = IntersectionOverUnion()\r\n",
        "metric.update(predicted_mask, target_mask)\r\n",
        "\r\n",
        "IOU_score = metric.compute()\r\n",
        "print(IOU_score)"
      ],
      "execution_count": 116,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0.75\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GWek8JBO6bqJ"
      },
      "source": [
        ""
      ]
    }
  ]
}