[bb64db]: / Generation / EEGNetV4_Generation_metrics_sub8.ipynb

Download this file

5402 lines (5401 with data), 134.8 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import numpy as np\n",
    "import os\n",
    "import clip\n",
    "from torch.nn import functional as F\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "train = False\n",
    "classes = None\n",
    "pictures= None\n",
    "\n",
    "def load_data():\n",
    "    \"\"\"Collect class descriptions and image paths from the THINGS image folders.\n",
    "\n",
    "    Behaviour is driven by the module-level switches ``train`` (which split\n",
    "    directory to read), ``classes`` (optional list of folder indices) and\n",
    "    ``pictures`` (optional per-class image index, parallel to ``classes``).\n",
    "\n",
    "    Returns:\n",
    "        (texts, images): ``texts`` holds the part of each selected folder name\n",
    "        after its first underscore; ``images`` holds image file paths in folder\n",
    "        order, with file names sorted inside each folder.\n",
    "    \"\"\"\n",
    "    image_exts = ('.png', '.jpg', '.jpeg')\n",
    "\n",
    "    if train:\n",
    "        img_directory = \"/home/ldy/Workspace/THINGS/images_set/training_images\"\n",
    "    else:\n",
    "        img_directory = \"/home/ldy/Workspace/THINGS/images_set/test_images\"\n",
    "\n",
    "    # Descriptions and images come from the same directory, so list it once.\n",
    "    all_folders = sorted(\n",
    "        d for d in os.listdir(img_directory)\n",
    "        if os.path.isdir(os.path.join(img_directory, d))\n",
    "    )\n",
    "\n",
    "    def list_images(folder):\n",
    "        # Sorted full paths of the image files directly inside one class folder.\n",
    "        folder_path = os.path.join(img_directory, folder)\n",
    "        return [\n",
    "            os.path.join(folder_path, name)\n",
    "            for name in sorted(os.listdir(folder_path))\n",
    "            if name.lower().endswith(image_exts)\n",
    "        ]\n",
    "\n",
    "    # --- text descriptions -------------------------------------------------\n",
    "    selected_folders = all_folders if classes is None else [all_folders[i] for i in classes]\n",
    "    texts = []\n",
    "    for folder in selected_folders:\n",
    "        _, sep, description = folder.partition('_')\n",
    "        if not sep:\n",
    "            print(f\"Skipped: {folder} due to no '_' found.\")\n",
    "            continue\n",
    "        texts.append(description)\n",
    "\n",
    "    # --- image paths -------------------------------------------------------\n",
    "    images = []\n",
    "    if classes is None:\n",
    "        for folder in all_folders:\n",
    "            images.extend(list_images(folder))\n",
    "    else:\n",
    "        for pos, class_idx in enumerate(classes):\n",
    "            pic_idx = None if pictures is None else pictures[pos]\n",
    "            # Out-of-range class indices are skipped silently, as before.\n",
    "            if class_idx >= len(all_folders):\n",
    "                continue\n",
    "            class_images = list_images(all_folders[class_idx])\n",
    "            if pic_idx is None:\n",
    "                images.extend(class_images)\n",
    "            elif pic_idx < len(class_images):\n",
    "                images.append(class_images[pic_idx])\n",
    "    return texts, images\n",
    "texts, images = load_data()\n",
    "# images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['aircraft_carrier',\n",
       " 'antelope',\n",
       " 'backscratcher',\n",
       " 'balance_beam',\n",
       " 'banana',\n",
       " 'baseball_bat',\n",
       " 'basil',\n",
       " 'basketball',\n",
       " 'bassoon',\n",
       " 'baton4',\n",
       " 'batter',\n",
       " 'beaver',\n",
       " 'bench',\n",
       " 'bike',\n",
       " 'birthday_cake',\n",
       " 'blowtorch',\n",
       " 'boat',\n",
       " 'bok_choy',\n",
       " 'bonnet',\n",
       " 'bottle_opener',\n",
       " 'brace',\n",
       " 'bread',\n",
       " 'breadbox',\n",
       " 'bug',\n",
       " 'buggy',\n",
       " 'bullet',\n",
       " 'bun',\n",
       " 'bush',\n",
       " 'calamari',\n",
       " 'candlestick',\n",
       " 'cart',\n",
       " 'cashew',\n",
       " 'cat',\n",
       " 'caterpillar',\n",
       " 'cd_player',\n",
       " 'chain',\n",
       " 'chaps',\n",
       " 'cheese',\n",
       " 'cheetah',\n",
       " 'chest2',\n",
       " 'chime',\n",
       " 'chopsticks',\n",
       " 'cleat',\n",
       " 'cleaver',\n",
       " 'coat',\n",
       " 'cobra',\n",
       " 'coconut',\n",
       " 'coffee_bean',\n",
       " 'coffeemaker',\n",
       " 'cookie',\n",
       " 'cordon_bleu',\n",
       " 'coverall',\n",
       " 'crab',\n",
       " 'creme_brulee',\n",
       " 'crepe',\n",
       " 'crib',\n",
       " 'croissant',\n",
       " 'crow',\n",
       " 'cruise_ship',\n",
       " 'crumb',\n",
       " 'cupcake',\n",
       " 'dagger',\n",
       " 'dalmatian',\n",
       " 'dessert',\n",
       " 'dragonfly',\n",
       " 'dreidel',\n",
       " 'drum',\n",
       " 'duffel_bag',\n",
       " 'eagle',\n",
       " 'eel',\n",
       " 'egg',\n",
       " 'elephant',\n",
       " 'espresso',\n",
       " 'face_mask',\n",
       " 'ferry',\n",
       " 'flamingo',\n",
       " 'folder',\n",
       " 'fork',\n",
       " 'freezer',\n",
       " 'french_horn',\n",
       " 'fruit',\n",
       " 'garlic',\n",
       " 'glove',\n",
       " 'golf_cart',\n",
       " 'gondola',\n",
       " 'goose',\n",
       " 'gopher',\n",
       " 'gorilla',\n",
       " 'grasshopper',\n",
       " 'grenade',\n",
       " 'hamburger',\n",
       " 'hammer',\n",
       " 'handbrake',\n",
       " 'headscarf',\n",
       " 'highchair',\n",
       " 'hoodie',\n",
       " 'hummingbird',\n",
       " 'ice_cube',\n",
       " 'ice_pack',\n",
       " 'jeep',\n",
       " 'jelly_bean',\n",
       " 'jukebox',\n",
       " 'kettle',\n",
       " 'kneepad',\n",
       " 'ladle',\n",
       " 'lamb',\n",
       " 'lampshade',\n",
       " 'laundry_basket',\n",
       " 'lettuce',\n",
       " 'lightning_bug',\n",
       " 'manatee',\n",
       " 'marijuana',\n",
       " 'meatloaf',\n",
       " 'metal_detector',\n",
       " 'minivan',\n",
       " 'modem',\n",
       " 'mosquito',\n",
       " 'muff',\n",
       " 'music_box',\n",
       " 'mussel',\n",
       " 'nightstand',\n",
       " 'okra',\n",
       " 'omelet',\n",
       " 'onion',\n",
       " 'orange',\n",
       " 'orchid',\n",
       " 'ostrich',\n",
       " 'pajamas',\n",
       " 'panther',\n",
       " 'paperweight',\n",
       " 'pear',\n",
       " 'pepper1',\n",
       " 'pheasant',\n",
       " 'pickax',\n",
       " 'pie',\n",
       " 'pigeon',\n",
       " 'piglet',\n",
       " 'pocket',\n",
       " 'pocketknife',\n",
       " 'popcorn',\n",
       " 'popsicle',\n",
       " 'possum',\n",
       " 'pretzel',\n",
       " 'pug',\n",
       " 'punch2',\n",
       " 'purse',\n",
       " 'radish',\n",
       " 'raspberry',\n",
       " 'recorder',\n",
       " 'rhinoceros',\n",
       " 'robot',\n",
       " 'rooster',\n",
       " 'rug',\n",
       " 'sailboat',\n",
       " 'sandal',\n",
       " 'sandpaper',\n",
       " 'sausage',\n",
       " 'scallion',\n",
       " 'scallop',\n",
       " 'scooter',\n",
       " 'seagull',\n",
       " 'seaweed',\n",
       " 'seed',\n",
       " 'skateboard',\n",
       " 'sled',\n",
       " 'sleeping_bag',\n",
       " 'slide',\n",
       " 'slingshot',\n",
       " 'snowshoe',\n",
       " 'spatula',\n",
       " 'spoon',\n",
       " 'station_wagon',\n",
       " 'stethoscope',\n",
       " 'strawberry',\n",
       " 'submarine',\n",
       " 'suit',\n",
       " 't-shirt',\n",
       " 'table',\n",
       " 'taillight',\n",
       " 'tape_recorder',\n",
       " 'television',\n",
       " 'tiara',\n",
       " 'tick',\n",
       " 'tomato_sauce',\n",
       " 'tongs',\n",
       " 'tool',\n",
       " 'top_hat',\n",
       " 'treadmill',\n",
       " 'tube_top',\n",
       " 'turkey',\n",
       " 'unicycle',\n",
       " 'vise',\n",
       " 'volleyball',\n",
       " 'wallpaper',\n",
       " 'walnut',\n",
       " 'wheat',\n",
       " 'wheelchair',\n",
       " 'windshield',\n",
       " 'wine',\n",
       " 'wok']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ldy/miniconda3/envs/BCI/lib/python3.10/site-packages/braindecode/models/base.py:23: UserWarning: EEGNetv4: 'in_chans' is depreciated. Use 'n_chans' instead.\n",
      "  warnings.warn(\n",
      "/home/ldy/miniconda3/envs/BCI/lib/python3.10/site-packages/braindecode/models/base.py:23: UserWarning: EEGNetv4: 'n_classes' is depreciated. Use 'n_outputs' instead.\n",
      "  warnings.warn(\n",
      "/home/ldy/miniconda3/envs/BCI/lib/python3.10/site-packages/braindecode/models/base.py:23: UserWarning: EEGNetv4: 'input_window_samples' is depreciated. Use 'n_times' instead.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of parameters: 1186833\n",
      "self.subjects ['sub-08']\n",
      "exclude_subject None\n",
      "Data tensor shape: torch.Size([200, 63, 250]), label tensor shape: torch.Size([200]), text length: 200, image length: 200\n",
      "features_tensor torch.Size([200, 1024])\n",
      " - Test Loss: 8.4386, Test Accuracy: 0.2300\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.nn import functional as F\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader\n",
    "# os.environ['http_proxy'] = 'http://10.16.35.10:13390' \n",
    "# os.environ['https_proxy'] = 'http://10.16.35.10:13390' \n",
    "\n",
    "os.environ[\"WANDB_API_KEY\"] = \"KEY\"\n",
    "os.environ[\"WANDB_MODE\"] = 'offline'\n",
    "from itertools import combinations\n",
    "\n",
    "import clip\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import tqdm\n",
    "from eegdatasets_leaveone import EEGDataset\n",
    "from eegencoder import eeg_encoder\n",
    "from einops.layers.torch import Rearrange, Reduce\n",
    "from lavis.models.clip_models.loss import ClipLoss\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import random\n",
    "from utils import wandb_logger\n",
    "from torch import Tensor\n",
    "import math\n",
    "from braindecode.models import EEGNetv4, ATCNet, EEGConformer, EEGITNet, ShallowFBCSPNet\n",
    "import csv\n",
    "\n",
    "\n",
    "\n",
    "device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class EEGNetv4_Encoder(nn.Module):\n",
    "    \"\"\"EEGNetv4 backbone projecting 63-channel, 250-sample EEG windows to 1024-d embeddings.\n",
    "\n",
    "    Also owns the learnable ``logit_scale`` and the ``ClipLoss`` instance that\n",
    "    the evaluation code in this notebook reads from the model.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.device = device\n",
    "        self.shape = (63, 250)  # (channels, time samples)\n",
    "        # Use the current braindecode keyword names; the previous ones\n",
    "        # (in_chans / n_classes / input_window_samples) only trigger\n",
    "        # deprecation warnings that point to these replacements.\n",
    "        self.eegnet = EEGNetv4(\n",
    "            n_chans=self.shape[0],\n",
    "            n_outputs=1024,\n",
    "            n_times=self.shape[1],\n",
    "            final_conv_length='auto',\n",
    "            pool_mode='mean',\n",
    "            F1=8,\n",
    "            D=20,\n",
    "            F2=160,\n",
    "            kernel_length=4,\n",
    "            third_kernel_size=(4, 2),\n",
    "            drop_prob=0.25\n",
    "        )\n",
    "        # Initialised to log(1 / 0.07).\n",
    "        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n",
    "        self.loss_func = ClipLoss()\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Encode a (batch, channels, time) EEG tensor and return the network output.\"\"\"\n",
    "        # Append a trailing singleton axis: (B, C, T) -> (B, C, T, 1). This is\n",
    "        # the same layout the former unsqueeze(0) + reshape produced.\n",
    "        return self.eegnet(data.unsqueeze(-1))\n",
    "    \n",
    "\n",
    "def get_eegfeatures(sub, eegmodel, dataloader, device, text_features_all, img_features_all, k, mode):\n",
    "    \"\"\"Embed every EEG batch, save the embeddings, and score k-way image retrieval.\n",
    "\n",
    "    Args:\n",
    "        sub: Subject id; used only in the saved file name.\n",
    "        eegmodel: Encoder exposing ``logit_scale`` and ``loss_func``.\n",
    "        dataloader: Yields ``(eeg, labels, text, text_features, img, img_features)``.\n",
    "        device: Device the batches and reference features are moved to.\n",
    "        text_features_all: Only its first dimension is read, as the number of classes.\n",
    "        img_features_all: Reference image features, indexed by class label.\n",
    "        k: Candidates per retrieval trial (the true class plus ``k - 1`` random\n",
    "            distractors); ``random.sample`` raises if ``k`` exceeds the class count.\n",
    "        mode: Tag such as 'train' or 'test'; used only in the saved file name.\n",
    "\n",
    "    Returns:\n",
    "        ``(average_loss, accuracy, labels, features)``. ``labels`` are the labels\n",
    "        of the LAST batch only (unchanged from the previous behaviour);\n",
    "        ``features`` is a CPU tensor of all EEG embeddings in loader order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``dataloader`` yields no batches.\n",
    "    \"\"\"\n",
    "    eegmodel.eval()\n",
    "    text_features_all = text_features_all.to(device).float()\n",
    "    img_features_all = img_features_all.to(device).float()\n",
    "    alpha = 0.9  # weight of the MSE term versus the contrastive term\n",
    "    mse_loss_fn = nn.MSELoss()\n",
    "    all_labels = set(range(text_features_all.size(0)))\n",
    "\n",
    "    total_loss = 0\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    num_batches = 0\n",
    "    labels = None\n",
    "    features_list = []\n",
    "    with torch.no_grad():\n",
    "        for eeg_data, labels, _text, _text_features, _img, img_features in dataloader:\n",
    "            num_batches += 1\n",
    "            eeg_data = eeg_data.to(device)[:, :, :250]\n",
    "            labels = labels.to(device)\n",
    "            img_features = img_features.to(device).float()\n",
    "            eeg_features = eegmodel(eeg_data).float()\n",
    "            features_list.append(eeg_features)\n",
    "            logit_scale = eegmodel.logit_scale\n",
    "\n",
    "            regress_loss = mse_loss_fn(eeg_features, img_features)\n",
    "            contrastive_loss = eegmodel.loss_func(eeg_features, img_features, logit_scale)\n",
    "            loss = alpha * regress_loss * 10 + (1 - alpha) * contrastive_loss * 10\n",
    "            total_loss += loss.item()\n",
    "\n",
    "            for idx, label in enumerate(labels):\n",
    "                target = label.item()\n",
    "                candidates = random.sample(list(all_labels - {target}), k - 1) + [target]\n",
    "                # NOTE(review): img_features_all is indexed by class label, which\n",
    "                # assumes one feature row per class. The recorded train run reports\n",
    "                # 16540 images for 1654 classes and chance-level accuracy, so this\n",
    "                # lookup is probably misaligned for the training split - confirm.\n",
    "                logits = logit_scale * eeg_features[idx] @ img_features_all[candidates].T\n",
    "                if candidates[torch.argmax(logits).item()] == target:\n",
    "                    correct += 1\n",
    "                total += 1\n",
    "\n",
    "    if num_batches == 0:\n",
    "        raise ValueError(\"dataloader produced no batches\")\n",
    "\n",
    "    features_tensor = torch.cat(features_list, dim=0).cpu()\n",
    "    print(\"features_tensor\", features_tensor.shape)\n",
    "    out_path = f\"emb_eeg/{config['encoder_type']}_eeg_features_{sub}_{mode}.pt\"\n",
    "    # torch.save does not create missing directories.\n",
    "    os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "    torch.save(features_tensor, out_path)  # Save features as .pt file\n",
    "\n",
    "    average_loss = total_loss / num_batches\n",
    "    accuracy = correct / total\n",
    "    return average_loss, accuracy, labels, features_tensor\n",
    "\n",
    "from IPython.display import Image, display\n",
    "config = {\n",
    "\"data_path\": \"/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz\",\n",
    "\"project\": \"train_pos_img_text_rep\",\n",
    "\"entity\": \"sustech_rethinkingbci\",\n",
    "\"name\": \"lr=3e-4_img_pos_pro_eeg\",\n",
    "\"lr\": 3e-4,\n",
    "\"epochs\": 50,\n",
    "\"batch_size\": 1024,\n",
    "\"logger\": True,\n",
    "\"encoder_type\":'EEGNetv4_Encoder',\n",
    "}\n",
    "\n",
    "# NOTE: this replaces the 'cuda:3' device assigned earlier in this cell.\n",
    "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "data_path = config['data_path']\n",
    "emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')\n",
    "emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')\n",
    "\n",
    "eeg_model = EEGNetv4_Encoder()\n",
    "print('number of parameters:', sum([p.numel() for p in eeg_model.parameters()]))\n",
    "\n",
    "#####################################################################################\n",
    "\n",
    "# eeg_model.load_state_dict(torch.load(\"/home/ldy/Workspace/Reconstruction/models/contrast/sub-08/01-30_00-44/40.pth\"))\n",
    "# map_location remaps the checkpoint tensors onto `device`, so loading also\n",
    "# works when the GPU the checkpoint was saved from is absent (or on CPU only).\n",
    "eeg_model.load_state_dict(torch.load(\"/home/ldy/Workspace/BrainAligning_retrieval/models/contrast/EEGNetv4_Encoder/sub-08/02-10_04-58/40.pth\", map_location=device))\n",
    "eeg_model = eeg_model.to(device)\n",
    "sub = 'sub-08'\n",
    "\n",
    "#####################################################################################\n",
    "\n",
    "test_dataset = EEGDataset(data_path, subjects= [sub], train=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=config[\"batch_size\"], shuffle=False, num_workers=0)\n",
    "text_features_test_all = test_dataset.text_features\n",
    "img_features_test_all = test_dataset.img_features\n",
    "test_loss, test_accuracy,labels, eeg_features_test = get_eegfeatures(sub, eeg_model, test_loader, device, text_features_test_all, img_features_test_all,k=200, mode=\"test\")\n",
    "print(f\" - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.subjects ['sub-08']\n",
      "exclude_subject None\n",
      "data_tensor torch.Size([66160, 63, 250])\n",
      "Data tensor shape: torch.Size([66160, 63, 250]), label tensor shape: torch.Size([66160]), text length: 1654, image length: 16540\n",
      "features_tensor torch.Size([66160, 1024])\n",
      " - Test Loss: 7.4921, Test Accuracy: 0.0050\n"
     ]
    }
   ],
   "source": [
    "#####################################################################################\n",
    "# Embed the TRAINING split with the same encoder and report its metrics.\n",
    "train_dataset = EEGDataset(data_path, subjects= [sub], train=True)\n",
    "train_loader = DataLoader(train_dataset, batch_size=config[\"batch_size\"], shuffle=False, num_workers=0)\n",
    "# NOTE: despite the `_test_` in their names, these two variables hold the\n",
    "# training-split features from here on; the names are kept unchanged so any\n",
    "# later cell that reads them keeps seeing the same values as before.\n",
    "text_features_test_all = train_dataset.text_features\n",
    "img_features_test_all = train_dataset.img_features\n",
    "\n",
    "train_loss, train_accuracy, labels, eeg_features_train = get_eegfeatures(sub, eeg_model, train_loader, device, text_features_test_all, img_features_test_all,k=200, mode=\"train\")\n",
    "print(f\" - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}\")\n",
    "#####################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import open_clip\n",
    "from matplotlib.font_manager import FontProperties\n",
    "\n",
    "import sys\n",
    "from diffusion_prior import *\n",
    "from custom_pipeline import *\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
    "device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_img_train_4 = emb_img_train.view(1654,10,1,1024).repeat(1,1,4,1).view(-1,1024)\n",
    "emb_eeg = torch.load(f\"/home/ldy/Workspace/Generation/emb_eeg/{config['encoder_type']}_eeg_features_{sub}_train.pt\")\n",
    "emb_eeg_test = torch.load(f\"/home/ldy/Workspace/Generation/emb_eeg/{config['encoder_type']}_eeg_features_{sub}_test.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([66160, 1024]), torch.Size([200, 1024]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb_eeg.shape, emb_eeg_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.6689, -0.0679, -0.0944,  ..., -0.4610, -0.4619,  0.6076],\n",
       "        [-1.2554,  2.1959,  0.2682,  ...,  0.1779, -0.9086, -0.6030],\n",
       "        [-0.0324,  0.0460, -0.0866,  ..., -0.4524,  0.2912, -0.2215],\n",
       "        ...,\n",
       "        [-0.8543,  0.7007,  1.0704,  ..., -0.1372,  0.0445,  1.1926],\n",
       "        [ 0.0635, -0.1275,  0.4409,  ..., -0.2616, -0.4191,  1.0185],\n",
       "        [ 0.0479, -0.3220,  0.2826,  ...,  0.3011, -0.4032,  0.4529]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eeg_features_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusion_prior import *\n",
    "class EmbeddingDataset(Dataset):\n",
    "    \"\"\"Paired dataset of conditioning embeddings (c) and target embeddings (h).\"\"\"\n",
    "\n",
    "    def __init__(self, c_embeddings=None, h_embeddings=None, h_embeds_uncond=None, cond_sampling_rate=0.5):\n",
    "        # Inputs are stored as given. h_embeds_uncond and cond_sampling_rate are\n",
    "        # recorded on the instance but not read by any method of this class.\n",
    "        self.c_embeddings = c_embeddings\n",
    "        self.h_embeddings = h_embeddings\n",
    "        self.h_embeds_uncond = h_embeds_uncond\n",
    "        self.cond_sampling_rate = cond_sampling_rate\n",
    "        self.N_cond = len(h_embeddings) if h_embeddings is not None else 0\n",
    "        self.N_uncond = len(h_embeds_uncond) if h_embeds_uncond is not None else 0\n",
    "\n",
    "    def __len__(self):\n",
    "        # The dataset length follows the target (h) embeddings only.\n",
    "        return self.N_cond\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Return the idx-th (condition, target) pair as a dict.\n",
    "        item = dict(\n",
    "            c_embedding=self.c_embeddings[idx],\n",
    "            h_embedding=self.h_embeddings[idx],\n",
    "        )\n",
    "        return item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9675648\n",
      "epoch: 0, loss: 1.1337378648611216\n",
      "epoch: 1, loss: 0.9277484618700468\n",
      "epoch: 2, loss: 0.7346047126329862\n",
      "epoch: 3, loss: 0.5879583936471205\n",
      "epoch: 4, loss: 0.48145529077603266\n",
      "epoch: 5, loss: 0.4021404353471903\n",
      "epoch: 6, loss: 0.34819340018125683\n",
      "epoch: 7, loss: 0.3164940769855793\n",
      "epoch: 8, loss: 0.29743201365837685\n",
      "epoch: 9, loss: 0.2812065954391773\n",
      "epoch: 10, loss: 0.2707386599137233\n",
      "epoch: 11, loss: 0.25827456024976875\n",
      "epoch: 12, loss: 0.24789529328162854\n",
      "epoch: 13, loss: 0.23868788549533257\n",
      "epoch: 14, loss: 0.2311222133728174\n",
      "epoch: 15, loss: 0.22522189250359168\n",
      "epoch: 16, loss: 0.21735877761474023\n",
      "epoch: 17, loss: 0.21265713274478912\n",
      "epoch: 18, loss: 0.21124263864297133\n",
      "epoch: 19, loss: 0.20706446170806886\n",
      "epoch: 20, loss: 0.20532444371626926\n",
      "epoch: 21, loss: 0.2046730555020846\n",
      "epoch: 22, loss: 0.20209349187520834\n",
      "epoch: 23, loss: 0.20230029294124016\n",
      "epoch: 24, loss: 0.19958276244310233\n",
      "epoch: 25, loss: 0.1956863997074274\n",
      "epoch: 26, loss: 0.19550649454960456\n",
      "epoch: 27, loss: 0.1932537264548815\n",
      "epoch: 28, loss: 0.1931295592051286\n",
      "epoch: 29, loss: 0.19213039852105654\n",
      "epoch: 30, loss: 0.19155625036129584\n",
      "epoch: 31, loss: 0.18909929784444662\n",
      "epoch: 32, loss: 0.18992915061803964\n",
      "epoch: 33, loss: 0.188588475722533\n",
      "epoch: 34, loss: 0.18476488429766436\n",
      "epoch: 35, loss: 0.18678866395583518\n",
      "epoch: 36, loss: 0.18525446424117456\n",
      "epoch: 37, loss: 0.18486184248557458\n",
      "epoch: 38, loss: 0.18248385351437788\n",
      "epoch: 39, loss: 0.1811525195837021\n",
      "epoch: 40, loss: 0.18144333362579346\n",
      "epoch: 41, loss: 0.17867409541056706\n",
      "epoch: 42, loss: 0.17740659071848944\n",
      "epoch: 43, loss: 0.1764389482828287\n",
      "epoch: 44, loss: 0.17700819281431346\n",
      "epoch: 45, loss: 0.17582294803399307\n",
      "epoch: 46, loss: 0.17429512349458842\n",
      "epoch: 47, loss: 0.1736076641541261\n",
      "epoch: 48, loss: 0.17225329073575826\n",
      "epoch: 49, loss: 0.17144493048007672\n",
      "epoch: 50, loss: 0.17105715228961063\n",
      "epoch: 51, loss: 0.17152268909491025\n",
      "epoch: 52, loss: 0.16880368773753826\n",
      "epoch: 53, loss: 0.170586480314915\n",
      "epoch: 54, loss: 0.16783449489336746\n",
      "epoch: 55, loss: 0.1691400450009566\n",
      "epoch: 56, loss: 0.1658798662515787\n",
      "epoch: 57, loss: 0.16694954060591183\n",
      "epoch: 58, loss: 0.16697401931652656\n",
      "epoch: 59, loss: 0.165232030244974\n",
      "epoch: 60, loss: 0.16396311521530152\n",
      "epoch: 61, loss: 0.16203268628854017\n",
      "epoch: 62, loss: 0.16698091740791615\n",
      "epoch: 63, loss: 0.16399981792156512\n",
      "epoch: 64, loss: 0.16448446810245515\n",
      "epoch: 65, loss: 0.16406533626409678\n",
      "epoch: 66, loss: 0.16220627793898948\n",
      "epoch: 67, loss: 0.1626904058914918\n",
      "epoch: 68, loss: 0.16437909328020536\n",
      "epoch: 69, loss: 0.1609313971721209\n",
      "epoch: 70, loss: 0.1631575719668315\n",
      "epoch: 71, loss: 0.1599584792669003\n",
      "epoch: 72, loss: 0.1629338324069977\n",
      "epoch: 73, loss: 0.16070766242650839\n",
      "epoch: 74, loss: 0.15935472937730644\n",
      "epoch: 75, loss: 0.16142174509855417\n",
      "epoch: 76, loss: 0.15977356502642998\n",
      "epoch: 77, loss: 0.15964984893798828\n",
      "epoch: 78, loss: 0.15815440118312835\n",
      "epoch: 79, loss: 0.1591795611840028\n",
      "epoch: 80, loss: 0.1592421217606618\n",
      "epoch: 81, loss: 0.15924676289925208\n",
      "epoch: 82, loss: 0.15849402982455033\n",
      "epoch: 83, loss: 0.1606162871305759\n",
      "epoch: 84, loss: 0.15699961208380186\n",
      "epoch: 85, loss: 0.15537095528382522\n",
      "epoch: 86, loss: 0.1578238546848297\n",
      "epoch: 87, loss: 0.15665148473702944\n",
      "epoch: 88, loss: 0.15740657655092385\n",
      "epoch: 89, loss: 0.1553252440232497\n",
      "epoch: 90, loss: 0.15652525012309734\n",
      "epoch: 91, loss: 0.15760714686833896\n",
      "epoch: 92, loss: 0.15551115549527683\n",
      "epoch: 93, loss: 0.15572300988894242\n",
      "epoch: 94, loss: 0.15536696429436023\n",
      "epoch: 95, loss: 0.1551545427395747\n",
      "epoch: 96, loss: 0.15626645042346074\n",
      "epoch: 97, loss: 0.15537708309980539\n",
      "epoch: 98, loss: 0.15510759628736057\n",
      "epoch: 99, loss: 0.1543263880106119\n",
      "epoch: 100, loss: 0.15425271919140449\n",
      "epoch: 101, loss: 0.15487114764176882\n",
      "epoch: 102, loss: 0.15255439717036026\n",
      "epoch: 103, loss: 0.1513884425163269\n",
      "epoch: 104, loss: 0.1535035007275068\n",
      "epoch: 105, loss: 0.15325337877640358\n",
      "epoch: 106, loss: 0.1540354270201463\n",
      "epoch: 107, loss: 0.154390015739661\n",
      "epoch: 108, loss: 0.15147823072396793\n",
      "epoch: 109, loss: 0.15220667123794557\n",
      "epoch: 110, loss: 0.1523587139753195\n",
      "epoch: 111, loss: 0.15091401178103228\n",
      "epoch: 112, loss: 0.15136499611230997\n",
      "epoch: 113, loss: 0.1518503189086914\n",
      "epoch: 114, loss: 0.1512597544835164\n",
      "epoch: 115, loss: 0.14968440463909735\n",
      "epoch: 116, loss: 0.1488885929951301\n",
      "epoch: 117, loss: 0.15059134180729206\n",
      "epoch: 118, loss: 0.1497446296306757\n",
      "epoch: 119, loss: 0.14996495407361252\n",
      "epoch: 120, loss: 0.15109946108781375\n",
      "epoch: 121, loss: 0.14965039904300984\n",
      "epoch: 122, loss: 0.1514803063410979\n",
      "epoch: 123, loss: 0.14924780818132255\n",
      "epoch: 124, loss: 0.14947404265403746\n",
      "epoch: 125, loss: 0.14996229318472054\n",
      "epoch: 126, loss: 0.1485686224240523\n",
      "epoch: 127, loss: 0.14805283431823438\n",
      "epoch: 128, loss: 0.14889377791147965\n",
      "epoch: 129, loss: 0.14838851529818314\n",
      "epoch: 130, loss: 0.1487965356845122\n",
      "epoch: 131, loss: 0.14777726393479568\n",
      "epoch: 132, loss: 0.14756032939140612\n",
      "epoch: 133, loss: 0.14845420878667098\n",
      "epoch: 134, loss: 0.1478582792557203\n",
      "epoch: 135, loss: 0.1495072433581719\n",
      "epoch: 136, loss: 0.14733073848944445\n",
      "epoch: 137, loss: 0.15026113734795496\n",
      "epoch: 138, loss: 0.14988462902032412\n",
      "epoch: 139, loss: 0.14790701659826133\n",
      "epoch: 140, loss: 0.14759095494563762\n",
      "epoch: 141, loss: 0.14727302835537837\n",
      "epoch: 142, loss: 0.14784078666797051\n",
      "epoch: 143, loss: 0.14688316400234516\n",
      "epoch: 144, loss: 0.1472667499230458\n",
      "epoch: 145, loss: 0.14967363591377553\n",
      "epoch: 146, loss: 0.14848633454396173\n",
      "epoch: 147, loss: 0.14954848610437832\n",
      "epoch: 148, loss: 0.14868795986358935\n",
      "epoch: 149, loss: 0.1491350281697053\n"
     ]
    }
   ],
   "source": [
    "# EmbeddingDataset takes its length from h_embeddings only, so a length\n",
    "# mismatch would otherwise surface as an IndexError mid-epoch or as silently\n",
    "# truncated EEG/image pairs. Fail early instead.\n",
    "assert len(eeg_features_train) == len(emb_img_train_4), (\n",
    "    f\"EEG/image embedding count mismatch: {len(eeg_features_train)} vs {len(emb_img_train_4)}\"\n",
    ")\n",
    "dataset = EmbeddingDataset(\n",
    "    c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4, \n",
    "    # h_embeds_uncond=h_embeds_imgnet\n",
    ")\n",
    "dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)\n",
    "diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
    "# number of parameters\n",
    "print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
    "pipe = Pipe(diffusion_prior, device=device)\n",
    "\n",
    "# Train the diffusion prior in this cell (no pretrained weights are loaded here).\n",
    "model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
    "pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Couldn't connect to the Hub: 504 Server Error: Gateway Time-out for url: https://huggingface.co/api/models/stabilityai/sdxl-turbo.\n",
      "Will try to load from local cache.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1aaf9a052c4471ebf7ec72a126b1bf1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 247.14it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0c00f7f085e4f8aba52e8ca7ab73633",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 257.56it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26734873e7734d16b8f55195b07f7f9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.17it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2ac18d49d004101ad78a1db658182d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.82it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "524c591e92bb4643b7e97ba380027278",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 234.46it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4eec485b4a1445d6ab1704dd3b5a2179",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 258.96it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0f8666997ac948e99396ee8033b4b030",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 258.36it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc957e6642234fa680fd376dab669222",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 258.16it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "102d173802ea46b2a3be5413d8d40069",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.88it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2644da46940a4aeaaccb796c352bce8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.86it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1025715eaf5a42b5957ad3c305d26167",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 228.01it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f99bc32314b6464788a7d0d3fee31b30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 248.91it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e22d0cc83874c178b332ad97b566c8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 249.80it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3f0755e9efe4ed09662e8b184315960",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 246.18it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d29d736249c4ff7bd3f357b33bc4fd2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 239.76it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd4dbc58c241406993fe13c238e40fda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.15it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d280f5953cd7479f80b50c40d7a26238",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 250.26it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a2de32ce6a740a8963b0952cc467005",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.09it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c49ff22ff7c44034aa7c970ce8f7c047",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 250.06it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3de8c6e20dfe4fa4a0cdc81615b8fac7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.38it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cea8406f617a4337b9ae918671c49578",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 230.41it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e2f26a4518c94bab980cbc4688a1821a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.36it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1cb3273977234e44b1a9a10e493d9212",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 257.12it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8846b53013564843b2e049f1f268cf2f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 229.61it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fea4a7baab4d43f98a701e08b3da2d17",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.01it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12033088d89c4176bf6c511cc0c0008e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.45it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "78775447c3c64cca9b93f0908d2a6015",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.27it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ec9e783ca314b428422439ef99c4751",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 123.73it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11bbbdd5ba04416f90f5103cc338c9b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.47it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "668e02034dec4bc89c27675dfa39cead",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.90it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b932fc41ab6d4aa1bda1c990bc240426",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.09it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1e3fd60cec749f2bb7ff93cf4fc93e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.54it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0abc5e2afb5e4b0ca4ebc3f5f407818f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.09it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39f382f2d47a41e4b31cc680a69b2c39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 218.48it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ddbdf267a9148849253ae8ad8fc18ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.98it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "612cd7d10e4e4a439e4affd23284af88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 241.64it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7d435a32cea419484979c00d321ddf2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 228.53it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0efd6862db544f349e75770184c4d945",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 247.11it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06a0d9582ec049bfa70372f42b0ad5c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 223.37it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7864e5da375e4e66843e01b716101ca6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.06it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87bb2c8793d648f0ac6279e9e6d974e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 242.74it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "34d4b1a07b0d435bbcce16240681b97e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 241.42it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a73bf1b37d66432b902957dee67a8bfb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 240.87it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "640222362bff4c12b30bcf5c683c001a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 233.99it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "061600bf670840c18f42021e8ef6ae40",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 236.18it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2284d0c92454048b38482c28edb7487",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.00it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f282e016aba84373ad9839d3207c3e77",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.59it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04239c2d4daf4d9f86977fd5673a0351",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 226.57it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41cda47a708d46fa9ed02fdf623ff43f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.34it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13a8faf4f8454f989fdcfde69191ff63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 237.57it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f887620668d415696a2f8dfc672a1ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 257.04it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "972e007ea00647259ff5dd84eb2955c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 259.62it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c979262491d1416f82caf6423d88e6b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 237.70it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b92b630c3ae4720b6aadecd8be1438f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 259.87it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6596ab5dcd5e4d1980b3d42100ed110e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 250.52it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d9d122edad34913b515515c105752ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.31it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f74d8a458c68495a9f21ad2b5823dfbd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.00it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8af5ddadd5224f18be6cd010bea6026f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 226.85it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7d315a57018d4afab5c4d2d955a88235",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 130.55it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3aab120d839c4fe3ad81242f55b2b64a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.11it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e14bbe70bbd64be7b85a6a82315ed968",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 234.40it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "508c1a13ed954d068eebea94116b622f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.12it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "816d6d467d1744d8aaa5c0313a474cc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 233.19it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca6aa87e560644e58c7e4384e97aa250",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 234.38it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "384c7075853d4c86a9ae3395c3b61532",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.54it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "064991500ef0444b89a866c394b1e310",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 240.72it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9471f48bacd546499df04a0a31993045",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.47it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4592853e833148f78a9db0006522c7a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 229.01it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1c7e123a75a0448097af4e627507940d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 236.31it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb37fe1a823d42b8969d1fcb5553244a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 237.98it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5df90e4358c746a4ac82bb2db1e1ecc4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 249.88it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9d5e9b765d141d7bb0a60a92c0cc21d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.24it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd92d5bc794c44cd8058e036e38e2de0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.98it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d92f9edb7824521b5b7d455a8f0d2dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.77it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1b31e3f863f48d285e7ddd6f3608e16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 236.53it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "733ee7b0f5264d8cbb1318868898feb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.86it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2aa6fbe0431c489391fbee2b92af7a33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.47it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3d979573a334a55816950a665938146",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 227.59it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c2b16d512374c328f4492943489d91e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.97it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "301244c529214c1d9ad3ca77458a7520",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.91it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4fb38fd149fc4c0f944aa79dadaff69f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 214.04it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2438f7af8f064618ad6086d3b2a98a96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.62it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ffc728de2403431fb06128c0378fa165",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.41it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42fb0f1853d94385a11d4e11777aae08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 228.21it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "372ac69b4ed443d3ab0b109c432f338b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.56it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de1dc7080f804ff8b04fd2794aab888f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 249.51it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c903dc3379264309a96a0dc545fc604c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 236.66it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9bcec69fca324dfdbb71f865bc03fb1c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 241.23it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da29d96a8330453d9a155aeb30b6f072",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 250.90it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "293dcbd077d34c28a90165d57b5cf7e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.91it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3e9e8da58ad4ff38db1da7f8c234f79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 124.99it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "354cd4a14f4f48faabd6ff34fa4c7877",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.79it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54a1886b87cb4dbdb60c46f3f373e7b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 243.63it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff6ba27830114e8195ee72a172896a03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 234.63it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "389c2c4222cb4957934dfb7ecefafd08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 246.35it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d28ce92b68094b53ae53dd675456388c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.43it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04011f4b4a01429a82e20f3537a6fab0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.96it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24053569fe5b4804aeb2ee6d965e908c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.10it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f69acb99d5444865be1fc5e7103e6575",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 244.43it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62c11fd9a6994b6797b7a4ef264ca205",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 234.28it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "681c96c3d6fa4f0fa3f221a5897c2144",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.68it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d5f2aeb26634132bda722a134f57706",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.05it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d29488a0421a4a25af4533e2f571a536",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.42it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "718ad7e8592745248a1380393e6c7c08",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.72it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c58a8d6ef7d430dbc1a66031b0be48a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.30it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0563678c2938436896b3af6ea69de3a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 242.30it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09a0f78acf254d808cc09f6b31418294",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 227.39it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb9a16d5a4844d2c95606be2634d2b6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 250.15it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7a059b1d5bb43ec946ee1a69ed5fb83",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.58it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bcb454d0fe7a46ffbd0f7ba5949d49ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 257.88it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fec2f69964c455096c23d22f5d4e94d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.50it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7677bdb2d9947cb9dc7036916497b9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.62it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ddd884ce1a8470681d30866e2997ec8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.87it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3456b0c4e204423590a924ad4710a79a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.67it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de1c61e65ef244a9bd10071644705d89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.36it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "870ebbd0b3794fab837ab0a66dd4f370",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.34it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6bac693aef24b09b96b8192a66af81b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 257.44it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02039cda5b654f268a2339497f27e5d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.24it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8574db4bd3254dbdbfed1a5fdcfce1cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.19it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "731aa7be7d2645ce81947aac8d78c9a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.96it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "943e28c6fbc24d72bfbcf49ddda096c6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.90it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7874ca03894f43ebb0b054646a1073c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.93it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "63ef489e12c341f5a5a3f269b013921f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.13it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9aaf7ab4324643f89759428635679194",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 125.50it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b2baedfd5e643289a85bf7b87f93a1a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.74it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0ad2f10799e4bc4a6278fb4829d32e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 247.48it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2d9906ce7d64265acb3d085d53a7bc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 247.34it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8ff40dc362d49b8904c8a7cbd4fe327",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.07it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4fff3a6f7434685aa2168c7901d2c8f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.47it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a145b0bdb29c457bb7def0fbfee36f4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.40it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d6afd40c9d894e98bdf09abcbbc65a22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.78it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28c70ec5368c403c906e6e123c846b38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.51it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27871ea93e7d47288129e4de65eec209",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.82it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e9c3af3d77d4e839ddf8abb94e18b42",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.53it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "508fc712453f4da899970def561aedb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.85it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9bf9c188086947d4bda8e51133052280",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 253.69it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "471bab1e3c074bddbad81aa0fc91b5cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.59it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e79f999ac78d4e03903472b0d95fe7ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 258.85it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "046d3b3093b84eb88a0626f014edecfd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.13it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "38ef3db69e7845f1abc3ec71e58766e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 259.89it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f11890f9f2740d593bd481fccfe4df1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.94it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a908088b3520415d9b5bd9840fce05ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 255.92it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6dda43fded55474594056d041f53d95c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.52it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4edeff4eeba641309e323a4ce8ba859a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 220.42it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0ec3321b6194d20b8bf96930cd90925",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.62it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3794c3c53e1741e5ac1558b7863a7b1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.13it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb0370b13fbe4be39258c21c6d35e5c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 258.30it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2443cc67e32948439a38224dff7e403e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.07it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1c1ca8603a9455dbabdaeb7731db2d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 228.27it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad776510ca3d4a10b8d5d27009a08ebf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 247.18it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ec1985d634f341fa80bbba16dbc7c26f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 246.50it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "922ae77306564dc8a19afdfc91ce0549",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.83it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11868bb8f3dc418f907b6969421a818c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 240.69it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "695cecc50a0a42b5b5046163e36fe6a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 254.24it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e77bf8978e2445acaa32ce8b3cd1ab96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.01it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d5e40ad0357e435e95980e4ad54d9094",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 216.19it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a737380d37e5455dbfca62dc8991bba0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 118.74it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2667af5b67df4ea6892ba017cf50f644",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 237.83it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7cb3cba7d9e4bb7914275d05338fef8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 231.89it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2d81f798da340838b76b103bce453a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.88it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d87d418e836c4caa979b436c49a39427",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 243.14it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "191bbdf9b55e4bc0814627362a4ee99f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.21it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d43db3c39174cb5bb9b786a0c556a52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 171.81it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "408f86d3fa44467195db8e49e3eebea7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 240.88it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "812dcf1595de4bb0a8986f6fe494fce5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 232.56it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "797f33ad8e954f79b399b63f2d22012d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 227.28it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3692501e6a9e4fe08fe7baa9620c4f3a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 230.84it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d61edd54fbc4f00a9bb72e6c420b3bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 248.06it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "114145f6297c4de18008358dfcf2d92d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 225.98it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8052c27504f405ba51d2bec21a137d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 241.52it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "38f323d606ad4402b934470031a5461c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.52it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47dc0e3c60a1449d9331c06163fa179c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 217.67it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "840fc006ac334a58b6fc0995c4ae46d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 245.98it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c62c770b4a04f50b015914fc6f52256",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 244.74it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41afacab3c8c44388fe511c819a81e75",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 237.88it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ebcb9680cdc4cb69f7a53fc896bb1b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 240.60it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3400aff17a1449e08a7e3f450dd7d718",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 239.31it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98341b0083b443fca0a6423bf316fcce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 249.34it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b696a4bda3a44fa98e3d6aad72738a8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 233.36it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0c1f58bc3f4b474aba16c6b2caa14d0b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 240.47it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f0d03c151a341d1b4df30af3a055ceb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 243.49it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "658821db72c34523b2c84f4f482c6b5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 202.36it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41d5d7ef548c4abe886e8858835c6c87",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 233.53it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96459e52d7c94f5f959e19a422aba39c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 236.42it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09930788dde24d34b7bb33b7f009938b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 222.06it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27c8c8135bc340faa08828647d94b712",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.76it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "43b060732b9e4052956329496851a75c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 248.66it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c95feb87b2e487d9dc4dd06d9163a2d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 247.35it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c3ebc636a757473dbb86d1a5b34abce9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 116.90it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be73a804ce4948aebe485974d7dd0ffc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 230.68it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c7fbe3d42d54fa6ae47d34e12e7afc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 238.50it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7164115353ea45cabda1a3f5066ece43",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 250.14it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6a62983b2834efea26f22ba179e1f95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 248.30it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "037446d117af4408bc2567a995950158",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.84it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a95cb244d64046d79ae98a5209152555",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 256.32it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "88bb9569d4e344b1940184b9804668f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 225.10it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d52ee5bb1b14add88d3be62732bec13",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.20it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "530a2fc8a9fe49269a4c50a2acbc8798",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 229.40it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c766ba2ae696417e8aa0ff576d9ff8bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 252.02it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14c3974747ac476db750a70870d798be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "50it [00:00, 251.67it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c95435e3fc24443ebcceed21d84d5214",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "pipe.diffusion_prior.load_state_dict(torch.load(f\"./fintune_ckpts/{config['encoder_type']}/{sub}/{model_name}.pt\", map_location=device))\n",
    "# save_path = f'./fintune_ckpts/{config[\"encoder_type\"]}/{sub}/{model_name}.pt'\n",
    "\n",
    "# directory = os.path.dirname(save_path)\n",
    "\n",
    "# Create the directory if it doesn't exist\n",
    "# os.makedirs(directory, exist_ok=True)\n",
    "# torch.save(pipe.diffusion_prior.state_dict(), save_path)\n",
    "from PIL import Image\n",
    "import os\n",
    "\n",
    "# Assuming generator.generate returns a PIL Image\n",
    "generator = Generator4Embeds(num_inference_steps=4, device=device)\n",
    "\n",
    "directory = f\"Generation/{config['encoder_type']}/generated_imgs/{sub}\"\n",
    "for k in range(200):\n",
    "    eeg_embeds = emb_eeg_test[k:k+1]\n",
    "    h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=50, guidance_scale=5.0)\n",
    "    for j in range(1):\n",
    "        image = generator.generate(h.to(dtype=torch.float16))\n",
    "        # Construct the save path for each image\n",
    "        path = f'{directory}/{texts[k]}/{j}.png'\n",
    "        # Ensure the directory exists\n",
    "        os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "        # Save the PIL Image\n",
    "        image.save(path)\n",
    "        # print(f'Image saved to {path}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "# # os.environ['http_proxy'] = 'http://10.16.35.10:13390' \n",
    "# # os.environ['https_proxy'] = 'http://10.16.35.10:13390' \n",
    "# # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
    "# # os.environ['PATH'] += os.pathsep + '/usr/local/texlive/2023/bin/x86_64-linux'\n",
    "\n",
    "# import torch\n",
    "# from torch import nn\n",
    "# import torch.nn.functional as F\n",
    "# import torchvision.transforms as transforms\n",
    "# import matplotlib.pyplot as plt\n",
    "# import open_clip\n",
    "# from matplotlib.font_manager import FontProperties\n",
    "\n",
    "# import sys\n",
    "# from diffusion_prior import *\n",
    "# from custom_pipeline import *\n",
    "\n",
    "# device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load eeg and image embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # load test image embeddings\n",
    "# data = torch.load('/home/ldy/Workspace/THINGS/CLIP/ViT-H-14_features_test.pt', map_location='cuda:3')\n",
    "# emb_img_test = data['img_features']\n",
    "\n",
    "# # load training image embeddings\n",
    "# data = torch.load('/home/ldy/Workspace/THINGS/CLIP/ViT-H-14_features_train.pt', map_location='cuda:3')\n",
    "# emb_img_train = data['img_features']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')\n",
    "# emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')\n",
    "\n",
    "# torch.save(emb_img_test.cpu().detach(), 'variables/ViT-H-14_features_test.pt')\n",
    "# torch.save(emb_img_train.cpu().detach(), 'variables/ViT-H-14_features_train.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# emb_img_test.shape, emb_img_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1654 classes x 10 images x 4 trials = 66160\n",
    "# emb_eeg = torch.load('/home/ldy/Workspace/Reconstruction/ATM_S_eeg_features_sub-08.pt')\n",
    "\n",
    "# emb_eeg_test = torch.load('/home/ldy/Workspace/Reconstruction/ATM_S_eeg_features_sub-08_test.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# emb_eeg.shape, emb_eeg_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training the diffusion prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class EmbeddingDataset(Dataset):\n",
    "    \"\"\"Index-aligned pairs of conditioning (c) and target (h) embeddings.\n",
    "\n",
    "    Args:\n",
    "        c_embeddings: indexable collection of conditioning embeddings.\n",
    "        h_embeddings: indexable collection of target embeddings; its length\n",
    "            is reported as the dataset length.\n",
    "        h_embeds_uncond: optional extra pool of target embeddings. It is\n",
    "            stored and counted (N_uncond) but never read by __getitem__.\n",
    "        cond_sampling_rate: stored on the instance; unused inside this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_embeddings=None, h_embeddings=None, h_embeds_uncond=None, cond_sampling_rate=0.5):\n",
    "        self.c_embeddings = c_embeddings\n",
    "        self.h_embeddings = h_embeddings\n",
    "        self.h_embeds_uncond = h_embeds_uncond\n",
    "        self.cond_sampling_rate = cond_sampling_rate\n",
    "        # A pool that was not supplied counts as empty.\n",
    "        self.N_cond = len(h_embeddings) if h_embeddings is not None else 0\n",
    "        self.N_uncond = len(h_embeds_uncond) if h_embeds_uncond is not None else 0\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of (c, h) pairs, i.e. len(h_embeddings) or 0 when it is None.\"\"\"\n",
    "        return self.N_cond\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the idx-th pair as a dict with keys 'c_embedding' and 'h_embedding'.\"\"\"\n",
    "        return dict(\n",
    "            c_embedding=self.c_embeddings[idx],\n",
    "            h_embedding=self.h_embeddings[idx],\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat each image embedding 4 times back-to-back:\n",
    "# (1654, 10, 1, 1024) -> (1654, 10, 4, 1024) -> (66160, 1024).\n",
    "# Presumably one copy per EEG trial so rows line up with emb_eeg - TODO confirm.\n",
    "emb_img_train_4 = (\n",
    "    emb_img_train\n",
    "    .view(1654, 10, 1, 1024)\n",
    "    .repeat(1, 1, 4, 1)\n",
    "    .view(-1, 1024)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the view/repeat/view chain in the previous cell yields (1654 * 10 * 4, 1024) = (66160, 1024).\n",
    "emb_img_train_4.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path_data = '/mnt/dataset0/weichen/projects/visobj/proposals/mise/data'\n",
    "# image_features = torch.load(os.path.join(path_data, 'openclip_emb/emb_imgnet.pt')) # 'emb_imgnet' or 'image_features'\n",
    "# h_embeds_imgnet = image_features['image_features']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Conditioning = EEG embeddings, target = the 4x-repeated image embeddings.\n",
    "# The unconditional pool (h_embeds_uncond=h_embeds_imgnet) stays disabled.\n",
    "dataset = EmbeddingDataset(c_embeddings=emb_eeg, h_embeddings=emb_img_train_4)\n",
    "print(len(dataset))\n",
    "dataloader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=1024,\n",
    "    shuffle=True,\n",
    "    num_workers=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior network (alternative kept for reference: DiffusionPrior(dropout=0.1)).\n",
    "diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
    "\n",
    "# Report the number of trainable parameters.\n",
    "print(sum(param.numel() for param in diffusion_prior.parameters() if param.requires_grad))\n",
    "\n",
    "# Wrap the prior in the pipeline object used for loading weights and generating below.\n",
    "pipe = Pipe(diffusion_prior, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pretrained prior weights; the checkpoint path is built from the encoder type and subject.\n",
    "# Alternative checkpoint names: 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'.\n",
    "# NOTE(review): 'fintune_ckpts' looks like a misspelling of 'finetune' but must match the\n",
    "# directory on disk - confirm before renaming.\n",
    "model_name = 'diffusion_prior'\n",
    "pipe.diffusion_prior.load_state_dict(\n",
    "    torch.load(\n",
    "        f\"./fintune_ckpts/{config['encoder_type']}/{sub}/{model_name}.pt\",\n",
    "        map_location=device,\n",
    "    )\n",
    ")\n",
    "# pipe.train(dataloader, num_epochs=150, learning_rate=1e-3) # to 0.142 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save model\n",
    "# torch.save(pipe.diffusion_prior.state_dict(), f'./ckpts/{model_name}.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generating from EEG embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import only `display`. Importing IPython's `Image` as well would rebind the\n",
    "# notebook-global `Image`, which the first cell binds to PIL's `Image` module.\n",
    "from IPython.display import display\n",
    "\n",
    "# Image generator driven directly by embeddings; configured for 4 inference steps.\n",
    "generator = Generator4Embeds(num_inference_steps=4, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth images live in /home/ldy/Workspace/THINGS/images_set/test_images\n",
    "k = 99\n",
    "image_embeds = emb_img_test[k:k + 1]  # slice (not index) so the leading dimension is kept\n",
    "print(f\"image_embeds {image_embeds.shape}\")\n",
    "\n",
    "# Render the selected test image embedding with the generator.\n",
    "image = generator.generate(image_embeds)\n",
    "display(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image_embeds = emb_eeg_test[k:k+1]\n",
    "# print(\"image_embeds\", image_embeds.shape)\n",
    "# image = generator.generate(image_embeds)\n",
    "# display(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generating from EEG-informed image embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# k = 0\n",
    "# Reuses k from the earlier cell unless the line above is re-enabled.\n",
    "eeg_embeds = emb_eeg_test[k:k+1]\n",
    "print(\"eeg_embeds\", eeg_embeds.shape)  # label matches the tensor being printed\n",
    "# The prior is conditioned on the EEG embedding (c_embeds); its output h is cast\n",
    "# to float16 and handed to the generator.\n",
    "h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=50, guidance_scale=5.0)\n",
    "image = generator.generate(h.to(dtype=torch.float16))\n",
    "display(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "BCI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}