[bb64db]: / Generation / Reconstruction_Metrics_ATM.ipynb

Download this file

555 lines (554 with data), 19.1 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53146de1-97c4-4bf9-9569-a6a400d377fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide setup: imports, compute device, RNG seeding and interactive conveniences.\n",
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import sys\n",
    "from datetime import datetime\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import make_grid\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Later cells move tensors and models onto this device (GPU when one is available).\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "local_rank = 0\n",
    "print(\"device:\",device)\n",
    "\n",
    "# Project-local helpers (seeding, interactive-session detection).\n",
    "import utils\n",
    "seed=42\n",
    "utils.seed_everything(seed=seed)  # presumably seeds python/numpy/torch RNGs -- see utils\n",
    "\n",
    "# Auto-reload edited project modules (such as utils) while working interactively.\n",
    "if utils.is_interactive():\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2\n",
    "\n",
    "# Default image size; later cells may redefine it.\n",
    "imsize = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c72944c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "# Pack every generated image under source_dir/<folder>/ into a single tensor file.\n",
    "source_dir = '/home/ldy/Workspace/Generation/generated_imgs/sub-01'\n",
    "target_dir = '/home/ldy/Workspace/Generation/generated_imgs_tensor'\n",
    "\n",
    "# Files whose lower-cased name does not end with one of these are skipped, so stray\n",
    "# non-image files cannot make Image.open fail.\n",
    "image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')\n",
    "\n",
    "os.makedirs(target_dir, exist_ok=True)\n",
    "\n",
    "# Reconstructions are paired with ground-truth images purely by position in the saved\n",
    "# tensors, so traversal must be deterministic: os.listdir order is arbitrary, hence sorted().\n",
    "tensor_list = []\n",
    "for folder_name in sorted(os.listdir(source_dir)):\n",
    "    folder_path = os.path.join(source_dir, folder_name)\n",
    "    if not os.path.isdir(folder_path):\n",
    "        continue\n",
    "    for image_name in sorted(os.listdir(folder_path)):\n",
    "        if not image_name.lower().endswith(image_exts):\n",
    "            continue\n",
    "        image_path = os.path.join(folder_path, image_name)\n",
    "        with Image.open(image_path) as img:\n",
    "            # float32 (3, H, W) in [0, 1] -- the layout/range the metric cells expect;\n",
    "            # convert('RGB') drops alpha and expands grayscale to 3 channels.\n",
    "            pixels = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0\n",
    "        tensor_list.append(torch.from_numpy(pixels).permute(2, 0, 1))\n",
    "\n",
    "if not tensor_list:\n",
    "    raise RuntimeError(f'No images found under {source_dir}')\n",
    "\n",
    "# (N, 3, H, W); torch.stack requires all images to share one size.\n",
    "all_tensors = torch.stack(tensor_list, dim=0)\n",
    "print(all_tensors.shape)\n",
    "\n",
    "# Save the combined tensor\n",
    "combined_tensor_path = os.path.join(target_dir, \"all_images.pt\")\n",
    "torch.save(all_tensors, combined_tensor_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc126d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from PIL import Image\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "# Pack every ground-truth test image under source_dir/<folder>/ into a single tensor file.\n",
    "source_dir = '/home/ldy/Workspace/THINGS/images_set/test_images'\n",
    "target_dir = '/home/ldy/Workspace/THINGS/images_set/test_images_tensor'\n",
    "\n",
    "# Files whose lower-cased name does not end with one of these are skipped, so stray\n",
    "# non-image files cannot make Image.open fail.\n",
    "image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')\n",
    "\n",
    "os.makedirs(target_dir, exist_ok=True)\n",
    "\n",
    "# Ground-truth images are paired with reconstructions purely by position in the saved\n",
    "# tensors, so traversal must be deterministic: os.listdir order is arbitrary, hence sorted().\n",
    "tensor_list = []\n",
    "for folder_name in sorted(os.listdir(source_dir)):\n",
    "    folder_path = os.path.join(source_dir, folder_name)\n",
    "    if not os.path.isdir(folder_path):\n",
    "        continue\n",
    "    for image_name in sorted(os.listdir(folder_path)):\n",
    "        if not image_name.lower().endswith(image_exts):\n",
    "            continue\n",
    "        image_path = os.path.join(folder_path, image_name)\n",
    "        with Image.open(image_path) as img:\n",
    "            # float32 (3, H, W) in [0, 1] -- the layout/range the metric cells expect;\n",
    "            # convert('RGB') drops alpha and expands grayscale to 3 channels.\n",
    "            pixels = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0\n",
    "        tensor_list.append(torch.from_numpy(pixels).permute(2, 0, 1))\n",
    "\n",
    "if not tensor_list:\n",
    "    raise RuntimeError(f'No images found under {source_dir}')\n",
    "\n",
    "# (N, 3, H, W); torch.stack requires all images to share one size.\n",
    "all_tensors = torch.stack(tensor_list, dim=0)\n",
    "print(all_tensors.shape)\n",
    "\n",
    "# Save the combined tensor\n",
    "combined_tensor_path = os.path.join(target_dir, \"all_images.pt\")\n",
    "torch.save(all_tensors, combined_tensor_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf904985-a9ed-4d12-a965-a46db5473d1b",
   "metadata": {},
   "source": [
    "# Configuration: load reconstructions and ground-truth images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adb10ae2-0188-454a-b227-3a90783bdaa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "recon_path = '/home/ldy/Workspace/Generation/generated_imgs_tensor/all_images.pt'\n",
    "all_images_path = '/home/ldy/Workspace/THINGS/images_set/test_images_tensor/all_images.pt'\n",
    "\n",
    "def to_float_nchw(images):\n",
    "    \"\"\"Return `images` as a float (N, 3, H, W) tensor with values in [0, 1].\n",
    "\n",
    "    Also accepts uint8 stacks of `np.array(PIL.Image)` arrays, i.e. (N, H, W, C) or\n",
    "    (N, H, W) with values 0-255, so older tensor files give the same metric inputs.\n",
    "    Single-channel input is repeated to 3 channels and a 4th (alpha) channel is dropped.\n",
    "    \"\"\"\n",
    "    if images.dtype == torch.uint8:\n",
    "        images = images.float() / 255.0\n",
    "    if images.ndim == 3:  # (N, H, W) grayscale\n",
    "        images = images.unsqueeze(1)\n",
    "    elif images.shape[1] not in (1, 3, 4) and images.shape[-1] in (1, 3, 4):\n",
    "        images = images.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)\n",
    "    if images.shape[1] == 1:\n",
    "        images = images.expand(-1, 3, -1, -1)\n",
    "    elif images.shape[1] == 4:\n",
    "        images = images[:, :3]\n",
    "    return images.float().contiguous()\n",
    "\n",
    "all_brain_recons = to_float_nchw(torch.load(recon_path, map_location='cpu'))\n",
    "all_images = to_float_nchw(torch.load(all_images_path, map_location='cpu'))\n",
    "\n",
    "print(all_images.shape)\n",
    "print(all_brain_recons.shape)\n",
    "\n",
    "# Every metric below pairs the i-th reconstruction with the i-th ground-truth image.\n",
    "if len(all_brain_recons) != len(all_images):\n",
    "    raise ValueError(f'{len(all_brain_recons)} reconstructions but {len(all_images)} ground-truth images')\n",
    "\n",
    "all_images = all_images.to(device)\n",
    "all_brain_recons = all_brain_recons.to(device).to(all_images.dtype).clamp(0,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbf9b290-4016-4d33-823b-e0624a68b700",
   "metadata": {},
   "source": [
    "# Display reconstructions next to ground truth images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db613b6a-33f1-459c-b494-52328297b7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both stacks are replaced by 256x256 versions; the metric cells below use these resized images.\n",
    "imsize = 256\n",
    "all_images = transforms.Resize((imsize,imsize))(all_images)\n",
    "all_brain_recons = transforms.Resize((imsize,imsize))(all_brain_recons)\n",
    "\n",
    "# Hand-picked examples, shown in reverse order. Indices beyond the loaded test set (e.g. 981\n",
    "# when fewer images are loaded) are dropped instead of raising IndexError; if none remain,\n",
    "# the first images are shown instead.\n",
    "ind = np.flip(np.array([112,119,101,44,159,22,173,174,175,189,981,243,249,255,265]))\n",
    "ind = ind[ind < len(all_images)]\n",
    "if len(ind) == 0:\n",
    "    ind = np.arange(min(15, len(all_images)))\n",
    "\n",
    "# Interleave: ground truth, reconstruction, ground truth, reconstruction, ...\n",
    "all_interleaved = torch.stack(\n",
    "    [pic for t in ind for pic in (all_images[t], all_brain_recons[t])]\n",
    ").float().cpu()\n",
    "\n",
    "plt.rcParams[\"savefig.bbox\"] = 'tight'\n",
    "def show(imgs,figsize):\n",
    "    \"\"\"Draw an image tensor (or a list of them) side by side without ticks; returns the figure.\"\"\"\n",
    "    if not isinstance(imgs, list):\n",
    "        imgs = [imgs]\n",
    "    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)\n",
    "    for i, img in enumerate(imgs):\n",
    "        img = img.detach()\n",
    "        img = transforms.ToPILImage()(img)\n",
    "        axs[0, i].imshow(np.asarray(img))\n",
    "        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
    "    return fig\n",
    "\n",
    "# nrow=10 -> five (ground truth, reconstruction) pairs per grid row.\n",
    "grid = make_grid(all_interleaved, nrow=10, padding=2)\n",
    "show(grid,figsize=(20,16));"
   ]
  },
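  {
   "cell_type": "markdown",
   "id": "be7ba39c-cd18-4942-a86a-ee640865bcdc",
   "metadata": {},
   "source": [
    "# Evaluation Metrics\n",
    "\n",
    "The helper below implements 2-way identification, which the AlexNet, InceptionV3 and CLIP sections use. PixCorr and SSIM compare pixels directly. The EfficientNet-B1 and SwAV sections report the mean correlation distance between features (lower is better)."
   ]
  },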
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49f8056e-a80c-4abc-8c2d-3193bb43f945",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names\n",
    "\n",
    "@torch.no_grad()\n",
    "def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True):\n",
    "    \"\"\"2-way identification of reconstructions against their ground-truth images.\n",
    "\n",
    "    Every image goes through `preprocess`. Each stack is then run through `model` on `device`,\n",
    "    and the output is flattened to one feature vector per image. If `feature_layer` is given,\n",
    "    `model` must return a mapping and that entry is used (e.g. a create_feature_extractor\n",
    "    node name); otherwise the model output itself is used.\n",
    "\n",
    "    Reconstruction j scores one success for every other ground-truth image i whose feature\n",
    "    correlation with it is lower than that of its own ground truth j.\n",
    "\n",
    "    Returns the mean success fraction over reconstructions when `return_avg` is True\n",
    "    (0.5 corresponds to chance). Otherwise returns a tuple (per-reconstruction success counts,\n",
    "    number of comparisons per reconstruction).\n",
    "    \"\"\"\n",
    "    def embed(images):\n",
    "        # One forward pass over the whole stack; outputs are cast to float32 before NumPy.\n",
    "        out = model(torch.stack([preprocess(x) for x in images], dim=0).to(device))\n",
    "        if feature_layer is not None:\n",
    "            out = out[feature_layer]\n",
    "        return out.float().flatten(1).cpu().numpy()\n",
    "\n",
    "    preds = embed(all_brain_recons)\n",
    "    reals = embed(all_images)\n",
    "    n = len(all_images)\n",
    "\n",
    "    # Upper-right block of the joint correlation matrix: corr[i, j] = corr(real_i, pred_j).\n",
    "    corr = np.corrcoef(reals, preds)[:n, n:]\n",
    "    own = np.diag(corr)  # own[j] = corr(real_j, pred_j)\n",
    "\n",
    "    # Broadcasting compares each column j against own[j]; the diagonal itself never counts.\n",
    "    success_cnt = (corr < own).sum(axis=0)\n",
    "\n",
    "    if not return_avg:\n",
    "        return success_cnt, n - 1\n",
    "    return np.mean(success_cnt) / (n - 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63801703-0f78-47ca-89c6-54dcf71156a4",
   "metadata": {},
   "source": [
    "## PixCorr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d89db923-16c8-4ec8-9bf9-cc9749031188",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both stacks are bilinearly resized to 425 on the short side before comparing pixels.\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "])\n",
    "\n",
    "# Flatten images while keeping the batch dimension (reshape also copes with non-contiguous tensors).\n",
    "all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()\n",
    "all_brain_recons_flattened = preprocess(all_brain_recons).reshape(len(all_brain_recons), -1).cpu()\n",
    "\n",
    "print(all_images_flattened.shape)\n",
    "print(all_brain_recons_flattened.shape)\n",
    "\n",
    "# PixCorr = mean over image pairs of the Pearson correlation between flattened pixels.\n",
    "# Loop over the actual number of pairs rather than a hard-coded dataset size.\n",
    "n_pairs = len(all_images_flattened)\n",
    "corrsum = 0\n",
    "for i in tqdm(range(n_pairs)):\n",
    "    corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]\n",
    "corrmean = corrsum / n_pairs\n",
    "\n",
    "pixcorr = corrmean\n",
    "print(pixcorr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2740d15b-1a02-4d50-9fe4-9dbfae38800e",
   "metadata": {},
   "source": [
    "## SSIM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a712ab76-1c8f-4232-995c-d06f8191d9b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# see https://github.com/zijin-gu/meshconv-decoding/issues/3\n",
    "from skimage.color import rgb2gray\n",
    "# Imported under its full name so the `ssim` result variable below does not shadow the function.\n",
    "from skimage.metrics import structural_similarity\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), \n",
    "])\n",
    "\n",
    "# Grayscale conversion: (N, 3, H, W) -> (N, H, W, 3) for rgb2gray -> (N, H, W).\n",
    "img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu().numpy())\n",
    "recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu().numpy())\n",
    "print(\"converted, now calculating ssim...\")\n",
    "\n",
    "# Each image is a single-channel 2-D array, so SSIM uses 2-D Gaussian windows\n",
    "# (channel_axis=None, the default). Do not pass multichannel=True: it declares the last\n",
    "# axis (image width) a channel axis, turning SSIM into per-column 1-D comparisons on\n",
    "# scikit-image versions that still honour that keyword (deprecated since 0.19).\n",
    "ssim_score=[]\n",
    "for im,rec in tqdm(zip(img_gray,recon_gray),total=len(all_images)):\n",
    "    ssim_score.append(structural_similarity(rec, im, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))\n",
    "\n",
    "ssim = np.mean(ssim_score)\n",
    "print(ssim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4554b22-7faa-4e59-a6db-83bf775dd9e4",
   "metadata": {
    "tags": []
   },
   "source": [
    "### AlexNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc906db-bc5f-4cce-87f4-878706c14d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import alexnet, AlexNet_Weights\n",
    "alex_weights = AlexNet_Weights.IMAGENET1K_V1\n",
    "\n",
    "# One extractor exposes both AlexNet activations scored below (features.4 and features.11).\n",
    "alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4','features.11']).to(device)\n",
    "alex_model.eval().requires_grad_(False)\n",
    "\n",
    "# see alex_weights.transforms(): bilinear resize to 256 + ImageNet mean/std (no center crop here)\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# (printed label, extractor node) for the early and mid layers, scored in this order.\n",
    "alex_layers = [('early, AlexNet(2)', 'features.4'), ('mid, AlexNet(5)', 'features.11')]\n",
    "alex_scores = []\n",
    "for layer, node in alex_layers:\n",
    "    print(f\"\\n---{layer}---\")\n",
    "    all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images,\n",
    "                                             alex_model, preprocess, node)\n",
    "    alex_scores.append(np.mean(all_per_correct))\n",
    "    print(f\"2-way Percent Correct: {alex_scores[-1]:.4f}\")\n",
    "alexnet2, alexnet5 = alex_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3a8ac02-a024-442f-a0a9-638a8afdeb0d",
   "metadata": {},
   "source": [
    "### InceptionV3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "555b088c-07b4-4164-9e38-bdb5d43d58ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import inception_v3, Inception_V3_Weights\n",
    "\n",
    "# see weights.transforms(): bilinear resize to 342 + ImageNet mean/std\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Score 2-way identification on the pooled ('avgpool') InceptionV3 features.\n",
    "weights = Inception_V3_Weights.DEFAULT\n",
    "inception_model = create_feature_extractor(inception_v3(weights=weights), return_nodes=['avgpool'])\n",
    "inception_model = inception_model.to(device)\n",
    "inception_model.eval().requires_grad_(False)\n",
    "\n",
    "all_per_correct = two_way_identification(all_brain_recons, all_images, inception_model,\n",
    "                                         preprocess, feature_layer='avgpool')\n",
    "inception = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {inception:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12537959-3045-46f2-833f-80ff6327d40a",
   "metadata": {},
   "source": [
    "### CLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc928ed-952c-4b2f-ae05-2cd8b8c770c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import clip\n",
    "\n",
    "# clip.load also returns its own preprocessing; it is replaced right away by a tensor\n",
    "# pipeline: bilinear resize to 224 + CLIP's normalization constants.\n",
    "clip_model, preprocess = clip.load(\"ViT-L/14\", device=device)\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
    "                         std=[0.26862954, 0.26130258, 0.27577711]),\n",
    "])\n",
    "\n",
    "# feature_layer=None: the output of encode_image (final layer) is compared directly.\n",
    "all_per_correct = two_way_identification(all_brain_recons, all_images, clip_model.encode_image,\n",
    "                                         preprocess, feature_layer=None)\n",
    "clip_ = np.mean(all_per_correct)\n",
    "print(f\"2-way Percent Correct: {clip_:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e00a5a08-8acb-4fd9-b581-697df7b845db",
   "metadata": {},
   "source": [
    "### EfficientNet-B1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d103b5af-ae7c-4f07-94e0-6b35d951ce78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scipy.spatial is imported explicitly: `import scipy as sp` alone does not load this\n",
    "# submodule on older SciPy releases, so `sp.spatial` could raise AttributeError.\n",
    "from scipy.spatial import distance as sp_distance\n",
    "from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights\n",
    "\n",
    "# Pooled ('avgpool') features of an ImageNet-pretrained EfficientNet-B1.\n",
    "weights = EfficientNet_B1_Weights.DEFAULT\n",
    "eff_model = create_feature_extractor(efficientnet_b1(weights=weights), \n",
    "                                    return_nodes=['avgpool']).to(device)\n",
    "eff_model.eval().requires_grad_(False)\n",
    "\n",
    "# see weights.transforms()\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Inference only: no_grad guarantees no autograd graph is built for the full-batch passes.\n",
    "with torch.no_grad():\n",
    "    gt = eff_model(preprocess(all_images))['avgpool']\n",
    "    fake = eff_model(preprocess(all_brain_recons))['avgpool']\n",
    "gt = gt.reshape(len(gt),-1).cpu().numpy()\n",
    "fake = fake.reshape(len(fake),-1).cpu().numpy()\n",
    "\n",
    "# Mean correlation distance (1 - Pearson r) over ground-truth/reconstruction pairs; lower is better.\n",
    "effnet = np.array([sp_distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()\n",
    "print(\"Distance:\",effnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81c764bd-4c46-4102-9342-5d63bc47c4e4",
   "metadata": {},
   "source": [
    "### SwAV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9eacb33d-06b7-4847-b106-3a8ae15798f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scipy.spatial is imported explicitly: `import scipy as sp` alone does not load this\n",
    "# submodule on older SciPy releases, so `sp.spatial` could raise AttributeError.\n",
    "from scipy.spatial import distance as sp_distance\n",
    "\n",
    "# SwAV ResNet-50 loaded via torch.hub, truncated to its pooled ('avgpool') features.\n",
    "swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')\n",
    "swav_model = create_feature_extractor(swav_model, \n",
    "                                    return_nodes=['avgpool']).to(device)\n",
    "swav_model.eval().requires_grad_(False)\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                         std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "# Inference only: no_grad guarantees no autograd graph is built for the full-batch passes.\n",
    "with torch.no_grad():\n",
    "    gt = swav_model(preprocess(all_images))['avgpool']\n",
    "    fake = swav_model(preprocess(all_brain_recons))['avgpool']\n",
    "gt = gt.reshape(len(gt),-1).cpu().numpy()\n",
    "fake = fake.reshape(len(fake),-1).cpu().numpy()\n",
    "\n",
    "# Mean correlation distance (1 - Pearson r) over ground-truth/reconstruction pairs; lower is better.\n",
    "swav = np.array([sp_distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()\n",
    "print(\"Distance:\",swav)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c8674ca-b936-41c0-bba3-ec87a94a5deb",
   "metadata": {},
   "source": [
    "# Display in table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02dbee6a-56fc-4bd7-9d52-4413c7a01a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the metrics computed above into one table (metric name -> value).\n",
    "metric_values = {\n",
    "    \"PixCorr\": pixcorr,\n",
    "    \"SSIM\": ssim,\n",
    "    \"AlexNet(2)\": alexnet2,\n",
    "    \"AlexNet(5)\": alexnet5,\n",
    "    \"InceptionV3\": inception,\n",
    "    \"CLIP\": clip_,\n",
    "    \"EffNet-B\": effnet,\n",
    "    \"SwAV\": swav,\n",
    "}\n",
    "data = {\"Metric\": list(metric_values.keys()), \"Value\": list(metric_values.values())}\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "print(df.to_string(index=False))\n",
    "\n",
    "# Outside interactive sessions, write the table as tab-separated text next to the\n",
    "# reconstruction tensor: recon_path with its trailing '.pt' replaced by '.csv'.\n",
    "if not utils.is_interactive():\n",
    "    df.to_csv(f'{recon_path[:-3]}.csv', sep='\\t', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}