5300 lines (5299 with data), 128.9 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import numpy as np\n",
"import os\n",
"import clip\n",
"from torch.nn import functional as F\n",
"import torch.nn as nn\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"train = False\n",
"classes = None\n",
"pictures= None\n",
"\n",
"def load_data():\n",
" data_list = []\n",
" label_list = []\n",
" texts = []\n",
" images = []\n",
" \n",
" if train:\n",
" text_directory = \"/home/ldy/Workspace/THINGS/images_set/training_images\" \n",
" else:\n",
" text_directory = \"/home/ldy/Workspace/THINGS/images_set/test_images\"\n",
"\n",
" dirnames = [d for d in os.listdir(text_directory) if os.path.isdir(os.path.join(text_directory, d))]\n",
" dirnames.sort()\n",
" \n",
" if classes is not None:\n",
" dirnames = [dirnames[i] for i in classes]\n",
"\n",
" for dir in dirnames:\n",
"\n",
" try:\n",
" idx = dir.index('_')\n",
" description = dir[idx+1:]\n",
" except ValueError:\n",
" print(f\"Skipped: {dir} due to no '_' found.\")\n",
" continue\n",
" \n",
" new_description = f\"{description}\"\n",
" texts.append(new_description)\n",
"\n",
" if train:\n",
" img_directory = \"/home/ldy/Workspace/THINGS/images_set/training_images\"\n",
" else:\n",
" img_directory =\"/home/ldy/Workspace/THINGS/images_set/test_images\"\n",
" \n",
" all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))]\n",
" all_folders.sort()\n",
"\n",
" if classes is not None and pictures is not None:\n",
" images = []\n",
" for i in range(len(classes)):\n",
" class_idx = classes[i]\n",
" pic_idx = pictures[i]\n",
" if class_idx < len(all_folders):\n",
" folder = all_folders[class_idx]\n",
" folder_path = os.path.join(img_directory, folder)\n",
" all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
" all_images.sort()\n",
" if pic_idx < len(all_images):\n",
" images.append(os.path.join(folder_path, all_images[pic_idx]))\n",
" elif classes is not None and pictures is None:\n",
" images = []\n",
" for i in range(len(classes)):\n",
" class_idx = classes[i]\n",
" if class_idx < len(all_folders):\n",
" folder = all_folders[class_idx]\n",
" folder_path = os.path.join(img_directory, folder)\n",
" all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
" all_images.sort()\n",
" images.extend(os.path.join(folder_path, img) for img in all_images)\n",
" elif classes is None:\n",
" images = []\n",
" for folder in all_folders:\n",
" folder_path = os.path.join(img_directory, folder)\n",
" all_images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
" all_images.sort() \n",
" images.extend(os.path.join(folder_path, img) for img in all_images)\n",
" else:\n",
"\n",
" print(\"Error\")\n",
" return texts, images\n",
"texts, images = load_data()\n",
"# images"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['aircraft_carrier',\n",
" 'antelope',\n",
" 'backscratcher',\n",
" 'balance_beam',\n",
" 'banana',\n",
" 'baseball_bat',\n",
" 'basil',\n",
" 'basketball',\n",
" 'bassoon',\n",
" 'baton4',\n",
" 'batter',\n",
" 'beaver',\n",
" 'bench',\n",
" 'bike',\n",
" 'birthday_cake',\n",
" 'blowtorch',\n",
" 'boat',\n",
" 'bok_choy',\n",
" 'bonnet',\n",
" 'bottle_opener',\n",
" 'brace',\n",
" 'bread',\n",
" 'breadbox',\n",
" 'bug',\n",
" 'buggy',\n",
" 'bullet',\n",
" 'bun',\n",
" 'bush',\n",
" 'calamari',\n",
" 'candlestick',\n",
" 'cart',\n",
" 'cashew',\n",
" 'cat',\n",
" 'caterpillar',\n",
" 'cd_player',\n",
" 'chain',\n",
" 'chaps',\n",
" 'cheese',\n",
" 'cheetah',\n",
" 'chest2',\n",
" 'chime',\n",
" 'chopsticks',\n",
" 'cleat',\n",
" 'cleaver',\n",
" 'coat',\n",
" 'cobra',\n",
" 'coconut',\n",
" 'coffee_bean',\n",
" 'coffeemaker',\n",
" 'cookie',\n",
" 'cordon_bleu',\n",
" 'coverall',\n",
" 'crab',\n",
" 'creme_brulee',\n",
" 'crepe',\n",
" 'crib',\n",
" 'croissant',\n",
" 'crow',\n",
" 'cruise_ship',\n",
" 'crumb',\n",
" 'cupcake',\n",
" 'dagger',\n",
" 'dalmatian',\n",
" 'dessert',\n",
" 'dragonfly',\n",
" 'dreidel',\n",
" 'drum',\n",
" 'duffel_bag',\n",
" 'eagle',\n",
" 'eel',\n",
" 'egg',\n",
" 'elephant',\n",
" 'espresso',\n",
" 'face_mask',\n",
" 'ferry',\n",
" 'flamingo',\n",
" 'folder',\n",
" 'fork',\n",
" 'freezer',\n",
" 'french_horn',\n",
" 'fruit',\n",
" 'garlic',\n",
" 'glove',\n",
" 'golf_cart',\n",
" 'gondola',\n",
" 'goose',\n",
" 'gopher',\n",
" 'gorilla',\n",
" 'grasshopper',\n",
" 'grenade',\n",
" 'hamburger',\n",
" 'hammer',\n",
" 'handbrake',\n",
" 'headscarf',\n",
" 'highchair',\n",
" 'hoodie',\n",
" 'hummingbird',\n",
" 'ice_cube',\n",
" 'ice_pack',\n",
" 'jeep',\n",
" 'jelly_bean',\n",
" 'jukebox',\n",
" 'kettle',\n",
" 'kneepad',\n",
" 'ladle',\n",
" 'lamb',\n",
" 'lampshade',\n",
" 'laundry_basket',\n",
" 'lettuce',\n",
" 'lightning_bug',\n",
" 'manatee',\n",
" 'marijuana',\n",
" 'meatloaf',\n",
" 'metal_detector',\n",
" 'minivan',\n",
" 'modem',\n",
" 'mosquito',\n",
" 'muff',\n",
" 'music_box',\n",
" 'mussel',\n",
" 'nightstand',\n",
" 'okra',\n",
" 'omelet',\n",
" 'onion',\n",
" 'orange',\n",
" 'orchid',\n",
" 'ostrich',\n",
" 'pajamas',\n",
" 'panther',\n",
" 'paperweight',\n",
" 'pear',\n",
" 'pepper1',\n",
" 'pheasant',\n",
" 'pickax',\n",
" 'pie',\n",
" 'pigeon',\n",
" 'piglet',\n",
" 'pocket',\n",
" 'pocketknife',\n",
" 'popcorn',\n",
" 'popsicle',\n",
" 'possum',\n",
" 'pretzel',\n",
" 'pug',\n",
" 'punch2',\n",
" 'purse',\n",
" 'radish',\n",
" 'raspberry',\n",
" 'recorder',\n",
" 'rhinoceros',\n",
" 'robot',\n",
" 'rooster',\n",
" 'rug',\n",
" 'sailboat',\n",
" 'sandal',\n",
" 'sandpaper',\n",
" 'sausage',\n",
" 'scallion',\n",
" 'scallop',\n",
" 'scooter',\n",
" 'seagull',\n",
" 'seaweed',\n",
" 'seed',\n",
" 'skateboard',\n",
" 'sled',\n",
" 'sleeping_bag',\n",
" 'slide',\n",
" 'slingshot',\n",
" 'snowshoe',\n",
" 'spatula',\n",
" 'spoon',\n",
" 'station_wagon',\n",
" 'stethoscope',\n",
" 'strawberry',\n",
" 'submarine',\n",
" 'suit',\n",
" 't-shirt',\n",
" 'table',\n",
" 'taillight',\n",
" 'tape_recorder',\n",
" 'television',\n",
" 'tiara',\n",
" 'tick',\n",
" 'tomato_sauce',\n",
" 'tongs',\n",
" 'tool',\n",
" 'top_hat',\n",
" 'treadmill',\n",
" 'tube_top',\n",
" 'turkey',\n",
" 'unicycle',\n",
" 'vise',\n",
" 'volleyball',\n",
" 'wallpaper',\n",
" 'walnut',\n",
" 'wheat',\n",
" 'wheelchair',\n",
" 'windshield',\n",
" 'wine',\n",
" 'wok']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"texts"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of parameters: 2630913\n",
"self.subjects ['sub-08']\n",
"exclude_subject None\n",
"Data tensor shape: torch.Size([200, 63, 250]), label tensor shape: torch.Size([200]), text length: 200, image length: 200\n",
"features_tensor torch.Size([200, 1024])\n",
" - Test Loss: 17.3687, Test Accuracy: 0.2900\n"
]
}
],
"source": [
"import os\n",
"\n",
"import torch\n",
"import torch.optim as optim\n",
"from torch.nn import CrossEntropyLoss\n",
"from torch.nn import functional as F\n",
"from torch.optim import Adam\n",
"from torch.utils.data import DataLoader\n",
"# os.environ['http_proxy'] = 'http://10.16.35.10:13390' \n",
"# os.environ['https_proxy'] = 'http://10.16.35.10:13390' \n",
"\n",
"os.environ[\"WANDB_API_KEY\"] = \"KEY\"\n",
"os.environ[\"WANDB_MODE\"] = 'offline'\n",
"from itertools import combinations\n",
"\n",
"import clip\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import torchvision.transforms as transforms\n",
"import tqdm\n",
"from eegdatasets_leaveone import EEGDataset\n",
"from eegencoder import eeg_encoder\n",
"from einops.layers.torch import Rearrange, Reduce\n",
"from lavis.models.clip_models.loss import ClipLoss\n",
"from sklearn.metrics import confusion_matrix\n",
"from torch.utils.data import DataLoader, Dataset\n",
"import random\n",
"from utils import wandb_logger\n",
"from torch import Tensor\n",
"import math\n",
"\n",
"\n",
"class PatchEmbedding(nn.Module):\n",
" def __init__(self, emb_size=40):\n",
" super().__init__()\n",
" # revised from shallownet\n",
" self.tsconv = nn.Sequential(\n",
" nn.Conv2d(1, 40, (1, 25), (1, 1)),\n",
" nn.AvgPool2d((1, 51), (1, 5)),\n",
" nn.BatchNorm2d(40),\n",
" nn.ELU(),\n",
" nn.Conv2d(40, 40, (63, 1), (1, 1)),\n",
" nn.BatchNorm2d(40),\n",
" nn.ELU(),\n",
" nn.Dropout(0.5),\n",
" )\n",
"\n",
" self.projection = nn.Sequential(\n",
" nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), \n",
" Rearrange('b e (h) (w) -> b (h w) e'),\n",
" )\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" # b, _, _, _ = x.shape\n",
" x = x.unsqueeze(1) \n",
" # print(\"x\", x.shape) \n",
" x = self.tsconv(x)\n",
" # print(\"tsconv\", x.shape) \n",
" x = self.projection(x)\n",
" # print(\"projection\", x.shape) \n",
" return x\n",
"\n",
"\n",
"class ResidualAdd(nn.Module):\n",
" def __init__(self, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
"\n",
" def forward(self, x, **kwargs):\n",
" res = x\n",
" x = self.fn(x, **kwargs)\n",
" x += res\n",
" return x\n",
"\n",
"\n",
"class FlattenHead(nn.Sequential):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" def forward(self, x):\n",
" x = x.contiguous().view(x.size(0), -1)\n",
" return x\n",
"\n",
"\n",
"class Enc_eeg(nn.Sequential):\n",
" def __init__(self, emb_size=40, **kwargs):\n",
" super().__init__(\n",
" PatchEmbedding(emb_size),\n",
" FlattenHead()\n",
" )\n",
"\n",
" \n",
"class Proj_eeg(nn.Sequential):\n",
" def __init__(self, embedding_dim=1440, proj_dim=1024, drop_proj=0.5):\n",
" super().__init__(\n",
" nn.Linear(embedding_dim, proj_dim),\n",
" ResidualAdd(nn.Sequential(\n",
" nn.GELU(),\n",
" nn.Linear(proj_dim, proj_dim),\n",
" nn.Dropout(drop_proj),\n",
" )),\n",
" nn.LayerNorm(proj_dim),\n",
" )\n",
"\n",
"\n",
"class Proj_img(nn.Sequential):\n",
" def __init__(self, embedding_dim=1024, proj_dim=1024, drop_proj=0.3):\n",
" super().__init__(\n",
" nn.Linear(embedding_dim, proj_dim),\n",
" ResidualAdd(nn.Sequential(\n",
" nn.GELU(),\n",
" nn.Linear(proj_dim, proj_dim),\n",
" nn.Dropout(drop_proj),\n",
" )),\n",
" nn.LayerNorm(proj_dim),\n",
" )\n",
" def forward(self, x):\n",
" return x \n",
"\n",
"class NICE(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.enc_eeg = Enc_eeg()\n",
" self.proj_eeg = Proj_eeg()\n",
" self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n",
" self.loss_func = ClipLoss() \n",
" def forward(self, data):\n",
" eeg_embedding = self.enc_eeg(data)\n",
" out = self.proj_eeg(eeg_embedding)\n",
"\n",
" return out \n",
" \n",
"\n",
"def get_eegfeatures(sub, eegmodel, dataloader, device, text_features_all, img_features_all, k, mode):\n",
" eegmodel.eval()\n",
" text_features_all = text_features_all.to(device).float()\n",
" img_features_all = img_features_all.to(device).float()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
" alpha =0.9\n",
" top5_correct = 0\n",
" top5_correct_count = 0\n",
"\n",
" all_labels = set(range(text_features_all.size(0)))\n",
" top5_acc = 0\n",
" mse_loss_fn = nn.MSELoss()\n",
" ridge_lambda = 0.1\n",
" save_features = True\n",
" features_list = [] # List to store features \n",
" with torch.no_grad():\n",
" for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):\n",
" eeg_data = eeg_data.to(device)\n",
" eeg_data = eeg_data[:, :, :250]\n",
" # print(\"eeg_data\", eeg_data.shape)\n",
" text_features = text_features.to(device).float()\n",
" labels = labels.to(device)\n",
" img_features = img_features.to(device).float()\n",
" eeg_features = eegmodel(eeg_data).float()\n",
" features_list.append(eeg_features)\n",
" logit_scale = eegmodel.logit_scale \n",
" \n",
" regress_loss = mse_loss_fn(eeg_features, img_features)\n",
" # print(\"eeg_features\", eeg_features.shape)\n",
" # print(torch.std(eeg_features, dim=-1))\n",
" # print(torch.std(img_features, dim=-1))\n",
" # l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())\n",
" # loss = (regress_loss + ridge_lambda * l2_norm) \n",
" img_loss = eegmodel.loss_func(eeg_features, img_features, logit_scale)\n",
" text_loss = eegmodel.loss_func(eeg_features, text_features, logit_scale)\n",
" contrastive_loss = img_loss\n",
" # loss = img_loss + text_loss\n",
"\n",
" regress_loss = mse_loss_fn(eeg_features, img_features)\n",
" # print(\"text_loss\", text_loss)\n",
" # print(\"img_loss\", img_loss)\n",
" # print(\"regress_loss\", regress_loss) \n",
" # l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())\n",
" # loss = (regress_loss + ridge_lambda * l2_norm) \n",
" loss = alpha * regress_loss *10 + (1 - alpha) * contrastive_loss*10\n",
" # print(\"loss\", loss)\n",
" total_loss += loss.item()\n",
" \n",
" for idx, label in enumerate(labels):\n",
"\n",
" possible_classes = list(all_labels - {label.item()})\n",
" selected_classes = random.sample(possible_classes, k-1) + [label.item()]\n",
" selected_img_features = img_features_all[selected_classes]\n",
" \n",
"\n",
" logits_img = logit_scale * eeg_features[idx] @ selected_img_features.T\n",
" # logits_text = logit_scale * eeg_features[idx] @ selected_text_features.T\n",
" # logits_single = (logits_text + logits_img) / 2.0\n",
" logits_single = logits_img\n",
" # print(\"logits_single\", logits_single.shape)\n",
"\n",
" # predicted_label = selected_classes[torch.argmax(logits_single).item()]\n",
" predicted_label = selected_classes[torch.argmax(logits_single).item()] # (n_batch, ) \\in {0, 1, ..., n_cls-1}\n",
" if predicted_label == label.item():\n",
" correct += 1 \n",
" total += 1\n",
"\n",
" if save_features:\n",
" features_tensor = torch.cat(features_list, dim=0)\n",
" print(\"features_tensor\", features_tensor.shape)\n",
" torch.save(features_tensor.cpu(), f\"emb_eeg/{config['encoder_type']}_eeg_features_{sub}_{mode}.pt\") # Save features as .pt file\n",
" average_loss = total_loss / (batch_idx+1)\n",
" accuracy = correct / total\n",
" return average_loss, accuracy, labels, features_tensor.cpu()\n",
"\n",
"from IPython.display import Image, display\n",
"config = {\n",
"\"data_path\": \"/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz\",\n",
"\"project\": \"train_pos_img_text_rep\",\n",
"\"entity\": \"sustech_rethinkingbci\",\n",
"\"name\": \"lr=3e-4_img_pos_pro_eeg\",\n",
"\"lr\": 3e-4,\n",
"\"epochs\": 50,\n",
"\"batch_size\": 1024,\n",
"\"logger\": True,\n",
"\"encoder_type\":'NICE',\n",
"}\n",
"\n",
"device = torch.device(\"cuda:6\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"data_path = config['data_path']\n",
"emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')\n",
"emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')\n",
"\n",
"eeg_model = NICE()\n",
"print('number of parameters:', sum([p.numel() for p in eeg_model.parameters()]))\n",
"\n",
"#####################################################################################\n",
"\n",
"# eeg_model.load_state_dict(torch.load(\"/home/ldy/Workspace/Reconstruction/models/contrast/sub-08/01-30_00-44/40.pth\"))\n",
"eeg_model.load_state_dict(torch.load(\"/home/ldy/Workspace/BrainAligning_retrieval/models/contrast/NICE/sub-08/02-10_01-58/best.pth\"))\n",
"eeg_model = eeg_model.to(device)\n",
"sub = 'sub-08'\n",
"\n",
"#####################################################################################\n",
"\n",
"test_dataset = EEGDataset(data_path, subjects= [sub], train=False)\n",
"test_loader = DataLoader(test_dataset, batch_size=config[\"batch_size\"], shuffle=False, num_workers=0)\n",
"text_features_test_all = test_dataset.text_features\n",
"img_features_test_all = test_dataset.img_features\n",
"test_loss, test_accuracy,labels, eeg_features_test = get_eegfeatures(sub, eeg_model, test_loader, device, text_features_test_all, img_features_test_all,k=200, mode=\"test\")\n",
"print(f\" - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"self.subjects ['sub-08']\n",
"exclude_subject None\n",
"data_tensor torch.Size([66160, 63, 250])\n",
"Data tensor shape: torch.Size([66160, 63, 250]), label tensor shape: torch.Size([66160]), text length: 1654, image length: 16540\n",
"features_tensor torch.Size([66160, 1024])\n",
" - Test Loss: 13.3945, Test Accuracy: 0.0049\n"
]
}
],
"source": [
"#####################################################################################\n",
"train_dataset = EEGDataset(data_path, subjects= [sub], train=True)\n",
"train_loader = DataLoader(train_dataset, batch_size=config[\"batch_size\"], shuffle=False, num_workers=0)\n",
"text_features_test_all = train_dataset.text_features\n",
"img_features_test_all = train_dataset.img_features\n",
"\n",
"train_loss, train_accuracy, labels, eeg_features_train = get_eegfeatures(sub, eeg_model, train_loader, device, text_features_test_all, img_features_test_all,k=200, mode=\"train\")\n",
"print(f\" - Test Loss: {train_loss:.4f}, Test Accuracy: {train_accuracy:.4f}\")\n",
"#####################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"import open_clip\n",
"from matplotlib.font_manager import FontProperties\n",
"\n",
"import sys\n",
"from diffusion_prior import *\n",
"from custom_pipeline import *\n",
"# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
"device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"emb_img_train_4 = emb_img_train.view(1654,10,1,1024).repeat(1,1,4,1).view(-1,1024)\n",
"emb_eeg = torch.load(f\"/home/ldy/Workspace/Generation/emb_eeg/{config['encoder_type']}_eeg_features_{sub}_train.pt\")\n",
"emb_eeg_test = torch.load(f\"/home/ldy/Workspace/Generation/emb_eeg/{config['encoder_type']}_eeg_features_{sub}_test.pt\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([66160, 1024]), torch.Size([200, 1024]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb_eeg.shape, emb_eeg_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4268, -0.6550, -0.3994, ..., -0.0573, -1.2790, 1.7172],\n",
" [-2.5021, 3.2850, -0.1402, ..., -0.4890, -0.6232, 0.1757],\n",
" [-1.8498, -0.1019, -1.6115, ..., -0.4604, 0.1076, 0.9637],\n",
" ...,\n",
" [-2.1374, 0.4877, -0.5477, ..., -1.2601, 0.0205, 2.3434],\n",
" [-1.7127, 1.2059, 0.3979, ..., -1.9886, -1.5540, 1.5707],\n",
" [ 1.2487, 1.1401, -0.4919, ..., -1.6411, -0.8680, 0.6948]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eeg_features_train"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from diffusion_prior import *\n",
"class EmbeddingDataset(Dataset):\n",
"\n",
" def __init__(self, c_embeddings=None, h_embeddings=None, h_embeds_uncond=None, cond_sampling_rate=0.5):\n",
" self.c_embeddings = c_embeddings\n",
" self.h_embeddings = h_embeddings\n",
" self.N_cond = 0 if self.h_embeddings is None else len(self.h_embeddings)\n",
" self.h_embeds_uncond = h_embeds_uncond\n",
" self.N_uncond = 0 if self.h_embeds_uncond is None else len(self.h_embeds_uncond)\n",
" self.cond_sampling_rate = cond_sampling_rate\n",
"\n",
" def __len__(self):\n",
" return self.N_cond\n",
"\n",
" def __getitem__(self, idx):\n",
" return {\n",
" \"c_embedding\": self.c_embeddings[idx],\n",
" \"h_embedding\": self.h_embeddings[idx]\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9675648\n"
]
}
],
"source": [
"dataset = EmbeddingDataset(\n",
" c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4, \n",
" # h_embeds_uncond=h_embeds_imgnet\n",
")\n",
"dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)\n",
"diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
"# number of parameters\n",
"print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
"pipe = Pipe(diffusion_prior, device=device)\n",
"\n",
"# load pretrained model\n",
"model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
"# pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142 "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7fddc3c70d17482ea3b77fa54f9f6cee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 98.87it/s] \n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b2d27412e1346028eb74b3a11047123",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 217.87it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e86a53883d8641b98241c8173d4f3fa6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.24it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "23aebe21cbcc40f9ac8d7e275e038b09",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.15it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7354989792724b0e8d5d34c8293ffc4a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.05it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3070e1bdc26f4f4c8e034f2cc84a895f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "94a1f45755f740beae6cba4a606ae26b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a6fe6629d524e05aed078ac3c3f972a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.61it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "112147daf7954471a85009cfd551bd47",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 209.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f58f22d60304d0295d3f7775ddec3b8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.57it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c880754c7668453a8a8b0ac0c93f2536",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.37it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fed582d6b7074e17bb57b04408a1c919",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.93it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c3b10f2d3b8644d1a4bbef890ac1286d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.48it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39bfafdaeef2433eace4c0cc71e765ac",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0fbd09e94f0454ba429f9a1db7a8078",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.04it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1f18f9d18e1a4938950061c6304fbd37",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.08it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4fb9d223acd4ef7b68e1a5b676688bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.66it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7d9b5da1ed924b95a3cd1adf17cd66a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 226.69it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "52c3e07e27bd4eec8eb6610a7dc9311b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.62it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "142ce270235f432f91b660efd89c21a1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7b3d13be932d46448e966f7f1660a695",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.89it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9c07be31e704c7e970c7792ab21fea7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.94it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "76c194f7cd394dcc8d3a82647f5148de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.84it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "45c4449256fd4bbabd9b54fb39ade163",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 212.70it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "96c1a6a4eff646eb902a971ce763a796",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.02it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b4e5a7aedb7844deafcedb0113806e92",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.16it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf8710d489184cd8a04d1041ccb59897",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 130.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fda709ff96b3451fb5460650811a536c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.86it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e32d215148347bfa21b052553e4abc4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.93it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a13dd00c59c04c558f8339c350f15e62",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.56it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ce4fe31a6af445ecb07825d0e037fd20",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.67it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "de7bc89f0a8a4c66bf928cb6e13049b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.10it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6f768a0d879449d942dcf09854c0ab5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.39it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1cbf15cdf60e4fb09fca034d7cef4239",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.95it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aee609fcc94347da82bc30e51bfccfb9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.51it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "67562ae308954d4fbaf93e093934aed2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 220.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c46d9b017a8246eb8fb3912955e10c8e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.07it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "01e329e3fa5044919123fb278eab9183",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.68it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "269f733181504252acf54170dee4080d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 224.77it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7270532468c94d7282751d8ea3a7d4dc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 235.16it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6cf48792cbb4ecb87fd04ee61fcc500",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 239.84it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dab0f916b6a74e59b563ddf4170852e2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 203.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2e034529e7454af197335cbe5289383d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.43it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "913bb9b99afa45a8a3126ab3f72d5074",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.67it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8bea9630fcf34ec8ac588de432aac4f7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.66it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a38b00d9b7834219bc53829a6a9e052f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.73it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "35451dfe04974cdf8eaddf69ced69795",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.38it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7ab0e84e7f244b888ebd9ef8e4ee172a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 220.76it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e80463b825bf4203aabebbd4ae022f6f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.30it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3123438a3e49467c92af76bfce840e59",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.76it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6b912246c5e48238f12f66aac94f2ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3a1232e908fa47a5a41f1ab647d73859",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a85a22fe302448f3bade4db13cbf259b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.21it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8d9ac37b9c1c45d3b13bd0147fc5980e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.25it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "45e6611eb3aa411d98b4001135906750",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.04it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a263483772824bbfb39a995aa41731b8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 123.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57f0ee03f446402fa2bdc88f499b1d61",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f036764b66df4758b8db6a0e677912ea",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.15it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "100b76d40a4b46aebd79cdf2e7d5f956",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6a8f0d14ec9f4f0fac2faacc30a91fef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.61it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d652371cfec49d5b0fd6867f319de11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bcbf9fee2fa54fc280cbe6598f77e094",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.70it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "74547b7901b445e78cc08e7c6f6047e1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.09it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "419f90efe714463993fff796cb4d782d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.81it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2dbfd2d6580a40dba56e06db7d5ac0d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 219.23it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "49e769b9526c408aa3b1fb43f0220a76",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 218.76it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc63e3dfd61c49e19a6c65f13552fb15",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.10it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d76d5827c19a46e593845dd31837c9d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.15it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e6cc4c8aaae4464ad2d28019aeb0cd0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.90it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "43751fad3c4349809e1eed1b2a0ce18b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.36it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e3e1f0ce131a40f0828688d5c84847dd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 226.52it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ce4782f75ac43c19836ca74ee10eced",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 239.09it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "33bf7d033479450a98024f51d5635fd9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "629b26bb67804416b60f7fb311438688",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 222.25it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b78607be617d4509a975f26e00148dfb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.99it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36c777c98d6b4cb0ab6f2fe68ee6116a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.61it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b6093c0fbb6747b997523478ad500cb3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a8f61148f51144bba94902c65fe475f9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 206.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8caabef2056a4d0a8e98d340f8644041",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.74it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b8044760e8d047e6ba8adc5b535c118b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.54it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "279df6b9f0474c9d9a02d9d1052f6309",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 215.06it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e3448e2852c406599be7eae3c1fb8ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.97it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ac2adf407d88456da99c2ec0c6cf63c1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.42it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b82c1f6da30c4c3ba5dcbab43e211513",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.92it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88cd06a5b7a24eeb90e3a37b061ed3cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.32it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d38475bc10c4eed843dde39ec50eb31",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.46it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf595086471c45b0bbc54668ef708c5d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 126.07it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c5f8cf4465f445b901114c7b1768cec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 219.05it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8c64e065f80b401680b7fd3a413e9f2b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4ca44ee73aa24c908dd72d8917495082",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc3f2059c26c49be9effbbfe0ffcef43",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.33it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "53f30f3cdc1743f6ab87dcc659436fa6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.20it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b0056af748b1423192815fc7ff3590a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.90it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "13ec26534eac44e080ec3a4df2a829a6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "72cf368365ce4132895da52c9298c950",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.30it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "19ca74683ef94a498c8a839b71a8d1c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.53it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ecc04d4e9b2942dd8fa8562f4f2febc4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.83it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7ab9734540c74935bfefe033951dba83",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 231.04it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "29a17d9dec3e4151a6d612d8a77a2f4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.97it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "140b792e4ae046cb832fa003508e95cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.05it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9920ac55ad3d424a9122f15f98a7821f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 211.55it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6f4732fc071345b39f33d969bb52128c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "20067e2d3afb4eaa90da8ccd35bcedda",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 239.81it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31496161c37a4084817d15172e6dab1a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.37it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5fe579a04ad14603a366712a3118f910",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 213.37it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "05441a838f474bef885946a6ce82ad5c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.75it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26da09b8c82f4c37b33e3edf495540a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 166.38it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3552167f3f704518968cafb1de9ad8b5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 218.25it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "855ff4bc1d1840d19561ef47a2ed43f5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.46it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "45872dbcb2fd4dfb94693c9e602e8901",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.33it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "be1337429e724ea8be729d04c4fea648",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.43it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4510c220801c4a6e9e506347a38f9ffd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.59it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3d30e05785354b23921dfcc0bda9787e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.59it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8cf2f8e341da49f9ae95c0924018a20b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 219.79it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2995441000ec4122b31558e41c2aba95",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.32it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "19de6e9c3235464c84a984bb6d7a745f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.25it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3fd57fe7218045e0829707d63d3e051c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.54it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6800df0d2c4d43b3990376297f770b3c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 127.69it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8108bfe1cd99487985df7d22756ed83c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 218.39it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1efa091883ab456a9ee6f0967ed18df8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 223.77it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8a9aecbca378486d96e118ccae9e0518",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.79it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d935c5f3291c405ba86d16dd040c61d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 235.40it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7ece5e8fef184fc1aba1e1d577c5b0ca",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.68it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b402788f32c241178174f3824e557513",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 204.60it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46283e85d0e0437b9a2682a171e13015",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.07it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4c2072cd70c24d79a395710625932f3a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.11it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28c08eef85bb4f80af93f8b9eaecb7f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 220.24it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cd68f22484f644f4a0f120f012143f23",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.79it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10a5ed11f0104ecb91231d4a1ba94470",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 239.70it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9a42af373ecc453d8e0835abf241fa1d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 218.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1a02927bb8c14b8ea8476aecb5cf9762",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f508cd4279454d2c8f838d7daffe0a47",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 208.33it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e04d42f132b94d73b68637ee7633f50f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.19it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bce736aba25b4ff6acfb05b3a5c0f5f5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.66it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "93e9d7a0d4c24d41a788442f4cc30b50",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "43f8c8c52ecb48e6a7ebcd75ebfe9b5b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.44it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2fc796a2095f45dfbd1e3b831d61ef4d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.38it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8bcb41985a784009bdbad31b4150d8fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 217.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cddae053e4574aafa178f47cecd1c188",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.39it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4b39e3a8bc2453e96408ec6793a8336",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.69it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8f35f5cf7f984f4ea364399a10435d4d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 211.38it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "72deb09c8d794c66aa7f933d2b5155e2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.19it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "01a09d57a53e44739486b1d4b9e42903",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 219.73it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99adba8266a1419ca24027d39c1cd303",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.80it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "49b49b5ed8544099aabc18d7da2005bb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.32it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a7457589a46f484a843bc9a43b85d4f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.40it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1bb162da15a04fd08daf168601d16168",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.32it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed17cb88a0fd4ec68b3f7b644a90bf71",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.61it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc46dee282f04e29b910126ab324b9d2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 114.27it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "faec1fbc33e1487a8e6c6ab38cd8ddc2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 223.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "21a4e6fd4457489382491e31255ac734",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 220.86it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "463555fa88074b27ae9a5845ec1ed68a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.87it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1378021bf2224354a79f4a947c2a2442",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.85it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18306a9903034f759845ab60264a264b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.89it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39a90e6ebf3e44588ed7880c7de6c0be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.74it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1f4024add55449c6bc7b694aa146f7ac",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 223.39it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fe49537103ce4c09957d8630694268d0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.81it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2a553028483a4e5bb13d4b45e35488cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.46it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a76ae9e4e084dd892234ee694817d43",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.47it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0db5bb28cb6246ef996d17b05311d4a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.97it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d1d8d19f33fa4eac8a3da3089ff5ddba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.50it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "148db84915a5469290619cb2e68e9973",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 226.23it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b38f68ff7b794fff87fc6a7fb941bc89",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.10it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9cbe74ac465947558ba77e54aa8f76d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.16it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "81583f6ff5074f1483e10e3201186602",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.44it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "511188411be346a495b1f9578aab9b27",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.44it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5fe76e0b2ac7419ab02cc529a08d84a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.96it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bf530244a5e94c10ad34de387965ea73",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb80ccb6bf334963b94e06ef9a93db46",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 235.71it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dd632060b8a140ac993867ecb7d76c19",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 211.61it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e1b880f93ff146a6b070eb7dce50bb81",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.13it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "81ae3de19e6742faaeec0cd872cd59b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c9cc46afe2374282ae7aeca44c21df78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 208.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "190a9e5878aa4858bd2dfbb2d0c93e3c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 221.01it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a7c322d357d849f18d3f1e2625b2421b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 224.46it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8c0efc551d1e4bb3a4b87bf727d00dc9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 206.80it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "812bfee495e24cba9ee92d5d19d49021",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.33it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c3dadc01506c4adf87e928c0ee417f4c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7932ccc706ee4a7d9256c2166b3f9709",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 222.82it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "65764d7d9005446cbb4db969f9ce3ddd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 111.33it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4c87a828e9f34391a1aad5e740c92e0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.40it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ddddc9e841bf417fb4d6a640b05c7aee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 212.68it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c7d155d00e45487da2aac49888069a47",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38097cdedbc74b48aa569f405920f3a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "21e731d2d7384531b4592c9febabbc62",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 218.64it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66937d1f91d147e99ecb989b09b2ba09",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 231.02it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57d1ba179ab348f3941f6b43cf7d1cee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.43it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "20c377c3f3de4d2d90cfe6337ef276ab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 223.99it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12e90c04b5a1419abf36408c781fd7f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.53it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "613b78dccf164165a7b51c47f4f5b370",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.04it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ad6accc870345d7b14b726e08a2efb6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.21it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "de7e44446aaa41b781931b08bfd4ad95",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.42it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46b0fcbc4df54b7d81e799336d5de0bc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.56it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "668cf63442534801879e92592494b50e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.28it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ac548dd0679c4743970df0c44bc8f731",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.58it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e00dea7b03684846ba42d9faae1045cf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.17it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf9e58c990d8440c971de073adb553e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.86it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4e7791c9033f4a10b5a8dc0e62b602fe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.48it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26ea8c951d274f72a1672ade4e7ada43",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8974e365434e4f40863cfd2fa9ebd4fb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.17it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18b235bad7a4492092fa35205f57e9fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# Load the fine-tuned diffusion-prior weights for the current encoder type / subject.\n",
"ckpt_path = f\"./fintune_ckpts/{config['encoder_type']}/{sub}/{model_name}.pt\"\n",
"pipe.diffusion_prior.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
"# save_path = f'./fintune_ckpts/{config[\"encoder_type\"]}/{sub}/{model_name}.pt'\n",
"\n",
"# directory = os.path.dirname(save_path)\n",
"\n",
"# Create the directory if it doesn't exist\n",
"# os.makedirs(directory, exist_ok=True)\n",
"# torch.save(pipe.diffusion_prior.state_dict(), save_path)\n",
"from PIL import Image\n",
"import os\n",
"\n",
"# Generation settings (previously inline literals).\n",
"NUM_TEST_SAMPLES = 200      # how many leading rows of emb_eeg_test / entries of texts to process\n",
"IMAGES_PER_SAMPLE = 1       # images generated per EEG embedding\n",
"PRIOR_INFERENCE_STEPS = 50  # passed to pipe.generate\n",
"GUIDANCE_SCALE = 5.0        # passed to pipe.generate\n",
"\n",
"# Fail fast: slicing past the end of emb_eeg_test would silently yield an empty batch,\n",
"# and texts[k] would only raise after part of the images had already been generated.\n",
"if len(emb_eeg_test) < NUM_TEST_SAMPLES or len(texts) < NUM_TEST_SAMPLES:\n",
"    raise ValueError(\n",
"        f\"Need at least {NUM_TEST_SAMPLES} EEG embeddings and labels, \"\n",
"        f\"got {len(emb_eeg_test)} embeddings and {len(texts)} labels.\"\n",
"    )\n",
"\n",
"# Assuming generator.generate returns a PIL Image\n",
"generator = Generator4Embeds(num_inference_steps=4, device=device)\n",
"\n",
"# Images are written to <directory>/<texts[k]>/<j>.png\n",
"directory = f\"Generation/{config['encoder_type']}/generated_imgs/{sub}\"\n",
"for k in range(NUM_TEST_SAMPLES):\n",
"    # Slice (not index) so the embedding keeps its leading batch dimension.\n",
"    eeg_embeds = emb_eeg_test[k:k+1]\n",
"    h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=PRIOR_INFERENCE_STEPS, guidance_scale=GUIDANCE_SCALE)\n",
"    for j in range(IMAGES_PER_SAMPLE):\n",
"        image = generator.generate(h.to(dtype=torch.float16))\n",
"        # Construct the save path for each image\n",
"        path = f'{directory}/{texts[k]}/{j}.png'\n",
"        # Ensure the directory exists\n",
"        os.makedirs(os.path.dirname(path), exist_ok=True)\n",
"        # Save the PIL Image\n",
"        image.save(path)\n",
"        # print(f'Image saved to {path}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import os\n",
"# # os.environ['http_proxy'] = 'http://10.16.35.10:13390' \n",
"# # os.environ['https_proxy'] = 'http://10.16.35.10:13390' \n",
"# # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
"# # os.environ['PATH'] += os.pathsep + '/usr/local/texlive/2023/bin/x86_64-linux'\n",
"\n",
"# import torch\n",
"# from torch import nn\n",
"# import torch.nn.functional as F\n",
"# import torchvision.transforms as transforms\n",
"# import matplotlib.pyplot as plt\n",
"# import open_clip\n",
"# from matplotlib.font_manager import FontProperties\n",
"\n",
"# import sys\n",
"# from diffusion_prior import *\n",
"# from custom_pipeline import *\n",
"\n",
"# device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load EEG and image embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # load image embeddings\n",
"# data = torch.load('/home/ldy/Workspace/THINGS/CLIP/ViT-H-14_features_test.pt', map_location='cuda:3')\n",
"# emb_img_test = data['img_features']\n",
"\n",
"# # load image embeddings\n",
"# data = torch.load('/home/ldy/Workspace/THINGS/CLIP/ViT-H-14_features_train.pt', map_location='cuda:3')\n",
"# emb_img_train = data['img_features']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')\n",
"# emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')\n",
"\n",
"# torch.save(emb_img_test.cpu().detach(), 'variables/ViT-H-14_features_test.pt')\n",
"# torch.save(emb_img_train.cpu().detach(), 'variables/ViT-H-14_features_train.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# emb_img_test.shape, emb_img_train.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1654clsx10imgsx4trials=66160\n",
"# emb_eeg = torch.load('/home/ldy/Workspace/Reconstruction/ATM_S_eeg_features_sub-08.pt')\n",
"\n",
"# emb_eeg_test = torch.load('/home/ldy/Workspace/Reconstruction/ATM_S_eeg_features_sub-08_test.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# emb_eeg.shape, emb_eeg_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training the diffusion prior"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class EmbeddingDataset(Dataset):\n",
"    \"\"\"Index-aligned pairs of conditioning embeddings and target embeddings.\n",
"\n",
"    Item ``idx`` is the dict\n",
"    ``{\"c_embedding\": c_embeddings[idx], \"h_embedding\": h_embeddings[idx]}``.\n",
"    The dataset length is ``len(h_embeddings)`` (0 when ``h_embeddings`` is None).\n",
"\n",
"    Args:\n",
"        c_embeddings: indexable sequence of conditioning embeddings.\n",
"        h_embeddings: indexable sequence of target embeddings, paired with\n",
"            ``c_embeddings`` by position.\n",
"        h_embeds_uncond: optional extra embeddings. They are stored and counted in\n",
"            ``N_uncond`` but are not read by ``__len__`` or ``__getitem__``.\n",
"        cond_sampling_rate: stored on the instance; not used by this class itself.\n",
"\n",
"    Raises:\n",
"        ValueError: if both ``c_embeddings`` and ``h_embeddings`` are given and\n",
"            their lengths differ.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, c_embeddings=None, h_embeddings=None, h_embeds_uncond=None, cond_sampling_rate=0.5):\n",
"        # Both sequences are indexed with the same idx, so a length mismatch would\n",
"        # either drop trailing conditioning rows silently or raise IndexError mid-epoch.\n",
"        if c_embeddings is not None and h_embeddings is not None and len(c_embeddings) != len(h_embeddings):\n",
"            raise ValueError(\n",
"                \"c_embeddings and h_embeddings must have the same length, \"\n",
"                f\"got {len(c_embeddings)} and {len(h_embeddings)}\"\n",
"            )\n",
"        self.c_embeddings = c_embeddings\n",
"        self.h_embeddings = h_embeddings\n",
"        self.N_cond = 0 if self.h_embeddings is None else len(self.h_embeddings)\n",
"        self.h_embeds_uncond = h_embeds_uncond\n",
"        self.N_uncond = 0 if self.h_embeds_uncond is None else len(self.h_embeds_uncond)\n",
"        self.cond_sampling_rate = cond_sampling_rate\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of target embeddings (``N_cond``).\"\"\"\n",
"        return self.N_cond\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return the conditioning/target pair stored at position ``idx``.\"\"\"\n",
"        return {\n",
"            \"c_embedding\": self.c_embeddings[idx],\n",
"            \"h_embedding\": self.h_embeddings[idx]\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Repeat every image embedding once per EEG trial so rows line up with the EEG embeddings\n",
"# (1654 classes x 10 images x 4 trials = 66160 rows).\n",
"N_CLASSES, IMGS_PER_CLASS, N_TRIALS, EMB_DIM = 1654, 10, 4, 1024\n",
"emb_img_per_image = emb_img_train.view(N_CLASSES, IMGS_PER_CLASS, 1, EMB_DIM)\n",
"emb_img_train_4 = emb_img_per_image.repeat(1, 1, N_TRIALS, 1).view(-1, EMB_DIM)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: view(1654,10,1,1024).repeat(1,1,4,1).view(-1,1024) should give (66160, 1024).\n",
"emb_img_train_4.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# path_data = '/mnt/dataset0/weichen/projects/visobj/proposals/mise/data'\n",
"# image_features = torch.load(os.path.join(path_data, 'openclip_emb/emb_imgnet.pt')) # 'emb_imgnet' or 'image_features'\n",
"# h_embeds_imgnet = image_features['image_features']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Pair each EEG embedding (condition) with its image embedding (target).\n",
"dataset = EmbeddingDataset(\n",
"    c_embeddings=emb_eeg, h_embeddings=emb_img_train_4,\n",
"    # h_embeds_uncond=h_embeds_imgnet\n",
")\n",
"print(len(dataset))\n",
"# Never request more worker processes than the host has CPUs; 64 remains the upper bound.\n",
"num_workers = min(64, os.cpu_count() or 1)\n",
"dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=num_workers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative architecture: diffusion_prior = DiffusionPrior(dropout=0.1)\n",
"diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
"\n",
"# Report how many parameters are trainable.\n",
"trainable_param_count = 0\n",
"for param in diffusion_prior.parameters():\n",
"    if param.requires_grad:\n",
"        trainable_param_count += param.numel()\n",
"print(trainable_param_count)\n",
"\n",
"pipe = Pipe(diffusion_prior, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore pretrained diffusion-prior weights from ./ckpts.\n",
"# Other checkpoint names: 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
"model_name = 'diffusion_prior'\n",
"pretrained_ckpt_path = f'./ckpts/{model_name}.pt'\n",
"# To train instead of loading: pipe.train(dataloader, num_epochs=150, learning_rate=1e-3) # to 0.142\n",
"pipe.diffusion_prior.load_state_dict(torch.load(pretrained_ckpt_path, map_location=device))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save model\n",
"# torch.save(pipe.diffusion_prior.state_dict(), f'./ckpts/{model_name}.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generating images from EEG embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save model\n",
"# torch.save(pipe.diffusion_prior.state_dict(), f'./ckpts/{model_name}.pt')\n",
"# IPython's Image is imported under an alias so it does not rebind the name `Image`,\n",
"# which this notebook also imports from PIL.\n",
"from IPython.display import Image as IPyImage, display\n",
"generator = Generator4Embeds(num_inference_steps=4, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# path of ground truth: /home/ldy/Workspace/THINGS/images_set/test_images\n",
"# Index of the test embedding to visualise; the EEG generation cell below reuses `k`.\n",
"k = 99\n",
"# Slice (rather than index) so the leading batch dimension is kept.\n",
"image_embeds = emb_img_test[k:k+1]\n",
"print(\"image_embeds\", image_embeds.shape)\n",
"# Generate directly from the image embedding (the diffusion prior is not involved here).\n",
"image = generator.generate(image_embeds)\n",
"display(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# image_embeds = emb_eeg_test[k:k+1]\n",
"# print(\"image_embeds\", image_embeds.shape)\n",
"# image = generator.generate(image_embeds)\n",
"# display(image)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generating images from EEG-informed image embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Uses the sample index `k` left by the image-embedding cell above; uncomment to override.\n",
"# k = 0\n",
"eeg_embeds = emb_eeg_test[k:k+1]\n",
"# Label fixed: this prints the shape of the EEG embedding, not an image embedding.\n",
"print(\"eeg_embeds\", eeg_embeds.shape)\n",
"# The diffusion prior produces `h` from the EEG embedding; the generator turns `h` into an image.\n",
"h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=50, guidance_scale=5.0)\n",
"image = generator.generate(h.to(dtype=torch.float16))\n",
"display(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "BCI",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}