5402 lines (5401 with data), 134.8 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import numpy as np\n",
"import os\n",
"import clip\n",
"from torch.nn import functional as F\n",
"import torch.nn as nn\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"# Options read by load_data(): split selection plus optional class / picture index filters.\n",
"train, classes, pictures = False, None, None\n",
"\n",
"def _list_image_files(folder_path):\n",
"    \"\"\"Return the sorted names of the .png/.jpg/.jpeg files (any letter case) in folder_path.\"\"\"\n",
"    return sorted(\n",
"        name for name in os.listdir(folder_path)\n",
"        if name.lower().endswith(('.png', '.jpg', '.jpeg'))\n",
"    )\n",
"\n",
"\n",
"def load_data():\n",
"    \"\"\"Collect class descriptions and image paths for one split of the THINGS image set.\n",
"\n",
"    Reads the module-level globals:\n",
"        train    -- True selects training_images, False selects test_images.\n",
"        classes  -- optional list of indices into the sorted class-folder list.\n",
"        pictures -- optional list parallel to `classes`; pictures[i] picks one image\n",
"                    (by sorted file name) from class classes[i].\n",
"\n",
"    Returns:\n",
"        texts:  for each selected class folder, its name after the first '_';\n",
"                folders without '_' are reported and skipped.\n",
"        images: full image paths -- one per (class, picture) pair when both filters\n",
"                are set, every image of the selected classes when only `classes` is\n",
"                set, and every image of every class otherwise. Out-of-range class or\n",
"                picture indices are skipped here (but an out-of-range class index\n",
"                raises IndexError while building `texts`).\n",
"\n",
"    NOTE(review): folders skipped for `texts` are not skipped for `images`, so the two\n",
"    lists can drift out of alignment if such folders exist.\n",
"    \"\"\"\n",
"    # Class names and images come from the same folder tree.\n",
"    if train:\n",
"        image_root = \"/home/ldy/Workspace/THINGS/images_set/training_images\"\n",
"    else:\n",
"        image_root = \"/home/ldy/Workspace/THINGS/images_set/test_images\"\n",
"\n",
"    all_folders = sorted(\n",
"        d for d in os.listdir(image_root)\n",
"        if os.path.isdir(os.path.join(image_root, d))\n",
"    )\n",
"\n",
"    # Class descriptions.\n",
"    selected_folders = all_folders if classes is None else [all_folders[i] for i in classes]\n",
"    texts = []\n",
"    for dirname in selected_folders:\n",
"        underscore = dirname.find('_')\n",
"        if underscore == -1:\n",
"            print(f\"Skipped: {dirname} due to no '_' found.\")\n",
"            continue\n",
"        texts.append(dirname[underscore + 1:])\n",
"\n",
"    # Image paths, filtered by `classes` / `pictures` when given.\n",
"    images = []\n",
"    if classes is None:\n",
"        for folder in all_folders:\n",
"            folder_path = os.path.join(image_root, folder)\n",
"            images.extend(os.path.join(folder_path, img) for img in _list_image_files(folder_path))\n",
"    elif pictures is None:\n",
"        for class_idx in classes:\n",
"            if class_idx < len(all_folders):\n",
"                folder_path = os.path.join(image_root, all_folders[class_idx])\n",
"                images.extend(os.path.join(folder_path, img) for img in _list_image_files(folder_path))\n",
"    else:\n",
"        for i, class_idx in enumerate(classes):\n",
"            pic_idx = pictures[i]\n",
"            if class_idx < len(all_folders):\n",
"                folder_path = os.path.join(image_root, all_folders[class_idx])\n",
"                image_files = _list_image_files(folder_path)\n",
"                if pic_idx < len(image_files):\n",
"                    images.append(os.path.join(folder_path, image_files[pic_idx]))\n",
"    return texts, images\n",
"# Build the class descriptions and image paths for the split configured above.\n",
"texts, images = load_data()\n",
"# images"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['aircraft_carrier',\n",
" 'antelope',\n",
" 'backscratcher',\n",
" 'balance_beam',\n",
" 'banana',\n",
" 'baseball_bat',\n",
" 'basil',\n",
" 'basketball',\n",
" 'bassoon',\n",
" 'baton4',\n",
" 'batter',\n",
" 'beaver',\n",
" 'bench',\n",
" 'bike',\n",
" 'birthday_cake',\n",
" 'blowtorch',\n",
" 'boat',\n",
" 'bok_choy',\n",
" 'bonnet',\n",
" 'bottle_opener',\n",
" 'brace',\n",
" 'bread',\n",
" 'breadbox',\n",
" 'bug',\n",
" 'buggy',\n",
" 'bullet',\n",
" 'bun',\n",
" 'bush',\n",
" 'calamari',\n",
" 'candlestick',\n",
" 'cart',\n",
" 'cashew',\n",
" 'cat',\n",
" 'caterpillar',\n",
" 'cd_player',\n",
" 'chain',\n",
" 'chaps',\n",
" 'cheese',\n",
" 'cheetah',\n",
" 'chest2',\n",
" 'chime',\n",
" 'chopsticks',\n",
" 'cleat',\n",
" 'cleaver',\n",
" 'coat',\n",
" 'cobra',\n",
" 'coconut',\n",
" 'coffee_bean',\n",
" 'coffeemaker',\n",
" 'cookie',\n",
" 'cordon_bleu',\n",
" 'coverall',\n",
" 'crab',\n",
" 'creme_brulee',\n",
" 'crepe',\n",
" 'crib',\n",
" 'croissant',\n",
" 'crow',\n",
" 'cruise_ship',\n",
" 'crumb',\n",
" 'cupcake',\n",
" 'dagger',\n",
" 'dalmatian',\n",
" 'dessert',\n",
" 'dragonfly',\n",
" 'dreidel',\n",
" 'drum',\n",
" 'duffel_bag',\n",
" 'eagle',\n",
" 'eel',\n",
" 'egg',\n",
" 'elephant',\n",
" 'espresso',\n",
" 'face_mask',\n",
" 'ferry',\n",
" 'flamingo',\n",
" 'folder',\n",
" 'fork',\n",
" 'freezer',\n",
" 'french_horn',\n",
" 'fruit',\n",
" 'garlic',\n",
" 'glove',\n",
" 'golf_cart',\n",
" 'gondola',\n",
" 'goose',\n",
" 'gopher',\n",
" 'gorilla',\n",
" 'grasshopper',\n",
" 'grenade',\n",
" 'hamburger',\n",
" 'hammer',\n",
" 'handbrake',\n",
" 'headscarf',\n",
" 'highchair',\n",
" 'hoodie',\n",
" 'hummingbird',\n",
" 'ice_cube',\n",
" 'ice_pack',\n",
" 'jeep',\n",
" 'jelly_bean',\n",
" 'jukebox',\n",
" 'kettle',\n",
" 'kneepad',\n",
" 'ladle',\n",
" 'lamb',\n",
" 'lampshade',\n",
" 'laundry_basket',\n",
" 'lettuce',\n",
" 'lightning_bug',\n",
" 'manatee',\n",
" 'marijuana',\n",
" 'meatloaf',\n",
" 'metal_detector',\n",
" 'minivan',\n",
" 'modem',\n",
" 'mosquito',\n",
" 'muff',\n",
" 'music_box',\n",
" 'mussel',\n",
" 'nightstand',\n",
" 'okra',\n",
" 'omelet',\n",
" 'onion',\n",
" 'orange',\n",
" 'orchid',\n",
" 'ostrich',\n",
" 'pajamas',\n",
" 'panther',\n",
" 'paperweight',\n",
" 'pear',\n",
" 'pepper1',\n",
" 'pheasant',\n",
" 'pickax',\n",
" 'pie',\n",
" 'pigeon',\n",
" 'piglet',\n",
" 'pocket',\n",
" 'pocketknife',\n",
" 'popcorn',\n",
" 'popsicle',\n",
" 'possum',\n",
" 'pretzel',\n",
" 'pug',\n",
" 'punch2',\n",
" 'purse',\n",
" 'radish',\n",
" 'raspberry',\n",
" 'recorder',\n",
" 'rhinoceros',\n",
" 'robot',\n",
" 'rooster',\n",
" 'rug',\n",
" 'sailboat',\n",
" 'sandal',\n",
" 'sandpaper',\n",
" 'sausage',\n",
" 'scallion',\n",
" 'scallop',\n",
" 'scooter',\n",
" 'seagull',\n",
" 'seaweed',\n",
" 'seed',\n",
" 'skateboard',\n",
" 'sled',\n",
" 'sleeping_bag',\n",
" 'slide',\n",
" 'slingshot',\n",
" 'snowshoe',\n",
" 'spatula',\n",
" 'spoon',\n",
" 'station_wagon',\n",
" 'stethoscope',\n",
" 'strawberry',\n",
" 'submarine',\n",
" 'suit',\n",
" 't-shirt',\n",
" 'table',\n",
" 'taillight',\n",
" 'tape_recorder',\n",
" 'television',\n",
" 'tiara',\n",
" 'tick',\n",
" 'tomato_sauce',\n",
" 'tongs',\n",
" 'tool',\n",
" 'top_hat',\n",
" 'treadmill',\n",
" 'tube_top',\n",
" 'turkey',\n",
" 'unicycle',\n",
" 'vise',\n",
" 'volleyball',\n",
" 'wallpaper',\n",
" 'walnut',\n",
" 'wheat',\n",
" 'wheelchair',\n",
" 'windshield',\n",
" 'wine',\n",
" 'wok']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Class descriptions returned by load_data().\n",
"texts"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ldy/miniconda3/envs/BCI/lib/python3.10/site-packages/braindecode/models/base.py:23: UserWarning: EEGNetv4: 'in_chans' is depreciated. Use 'n_chans' instead.\n",
" warnings.warn(\n",
"/home/ldy/miniconda3/envs/BCI/lib/python3.10/site-packages/braindecode/models/base.py:23: UserWarning: EEGNetv4: 'n_classes' is depreciated. Use 'n_outputs' instead.\n",
" warnings.warn(\n",
"/home/ldy/miniconda3/envs/BCI/lib/python3.10/site-packages/braindecode/models/base.py:23: UserWarning: EEGNetv4: 'input_window_samples' is depreciated. Use 'n_times' instead.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of parameters: 1186833\n",
"self.subjects ['sub-08']\n",
"exclude_subject None\n",
"Data tensor shape: torch.Size([200, 63, 250]), label tensor shape: torch.Size([200]), text length: 200, image length: 200\n",
"features_tensor torch.Size([200, 1024])\n",
" - Test Loss: 8.4386, Test Accuracy: 0.2300\n"
]
}
],
"source": [
"import os\n",
"\n",
"import torch\n",
"import torch.optim as optim\n",
"from torch.nn import CrossEntropyLoss\n",
"from torch.nn import functional as F\n",
"from torch.optim import Adam\n",
"from torch.utils.data import DataLoader\n",
"# Weights & Biases runs offline in this notebook. Keep any API key already present in the\n",
"# environment instead of overwriting it with the placeholder.\n",
"os.environ.setdefault(\"WANDB_API_KEY\", \"KEY\")\n",
"os.environ[\"WANDB_MODE\"] = 'offline'\n",
"from itertools import combinations\n",
"\n",
"import clip\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import torchvision.transforms as transforms\n",
"import tqdm\n",
"from eegdatasets_leaveone import EEGDataset\n",
"from eegencoder import eeg_encoder\n",
"from einops.layers.torch import Rearrange, Reduce\n",
"from lavis.models.clip_models.loss import ClipLoss\n",
"from sklearn.metrics import confusion_matrix\n",
"from torch.utils.data import DataLoader, Dataset\n",
"import random\n",
"from utils import wandb_logger\n",
"from torch import Tensor\n",
"import math\n",
"from braindecode.models import EEGNetv4, ATCNet, EEGConformer, EEGITNet, ShallowFBCSPNet\n",
"import csv\n",
"\n",
"\n",
"\n",
"# Kept in sync with the `device = ... \"cuda:2\"` assignment further down this cell, which is the\n",
"# value in effect when the encoder is constructed.\n",
"device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"class EEGNetv4_Encoder(nn.Module):\n",
"    \"\"\"EEG encoder: braindecode EEGNetv4 mapping a 63-channel x 250-sample trial to a 1024-d vector.\n",
"\n",
"    Also holds the CLIP-style learnable `logit_scale` (initialised to log(1/0.07)) and\n",
"    the `ClipLoss` used to score EEG embeddings against image / text features.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Snapshot of the module-level `device` at construction time.\n",
"        self.device = device\n",
"        self.shape = (63, 250)  # (EEG channels, time samples)\n",
"        # Current braindecode argument names; the old in_chans / n_classes /\n",
"        # input_window_samples spellings only trigger deprecation warnings.\n",
"        self.eegnet = EEGNetv4(\n",
"            n_chans=self.shape[0],\n",
"            n_outputs=1024,\n",
"            n_times=self.shape[1],\n",
"            final_conv_length='auto',\n",
"            pool_mode='mean',\n",
"            F1=8,\n",
"            D=20,\n",
"            F2=160,\n",
"            kernel_length=4,\n",
"            third_kernel_size=(4, 2),\n",
"            drop_prob=0.25,\n",
"        )\n",
"        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n",
"        self.loss_func = ClipLoss()\n",
"\n",
"    def forward(self, data):\n",
"        \"\"\"Embed a batch of EEG trials of shape (batch, channels, time).\n",
"\n",
"        Returns the EEGNetv4 output, one n_outputs=1024 vector per trial.\n",
"        \"\"\"\n",
"        # (B, C, T) -> (B, C, T, 1). Identical to the former unsqueeze(0) followed by\n",
"        # reshape(B, C, T, 1): the added leading axis has size 1, so element order is unchanged.\n",
"        return self.eegnet(data.unsqueeze(-1))\n",
" \n",
"\n",
"def get_eegfeatures(sub, eegmodel, dataloader, device, text_features_all, img_features_all, k, mode, save_features=True, alpha=0.9):\n",
"    \"\"\"Embed every EEG trial in `dataloader` and report loss and k-way retrieval accuracy.\n",
"\n",
"    Args:\n",
"        sub: subject id; only used in the saved file name.\n",
"        eegmodel: encoder exposing `logit_scale` and `loss_func` (e.g. EEGNetv4_Encoder).\n",
"        dataloader: yields (eeg_data, labels, text, text_features, img, img_features) batches.\n",
"        device: torch device to run on.\n",
"        text_features_all: per-class text features; only its row count is used, to define\n",
"            the label space 0 .. N-1.\n",
"        img_features_all: image features indexed by class label (row i = class i).\n",
"        k: candidates per trial -- the true class plus k-1 distinct random other classes.\n",
"        mode: split tag (\"train\" / \"test\") used in the saved file name.\n",
"        save_features: save the embeddings to\n",
"            emb_eeg/{config['encoder_type']}_eeg_features_{sub}_{mode}.pt (reads the\n",
"            module-level `config`).\n",
"        alpha: weight of the MSE term against the contrastive term in the reported loss.\n",
"\n",
"    Returns:\n",
"        (average_loss, accuracy, labels, features): mean per-batch loss, k-way top-1\n",
"        accuracy, the label tensor of the LAST batch only, and all EEG embeddings\n",
"        concatenated in loader order on the CPU.\n",
"\n",
"    Raises:\n",
"        ValueError: if the dataloader yields no batches.\n",
"\n",
"    The distractor classes come from `random.sample`, so accuracy varies between runs\n",
"    unless Python's `random` module is seeded.\n",
"    \"\"\"\n",
"    eegmodel.eval()\n",
"    img_features_all = img_features_all.to(device).float()\n",
"    all_labels = set(range(text_features_all.size(0)))\n",
"    mse_loss_fn = nn.MSELoss()\n",
"\n",
"    total_loss = 0.0\n",
"    n_batches = 0\n",
"    correct = 0\n",
"    total = 0\n",
"    features_list = []\n",
"    labels = None\n",
"\n",
"    with torch.no_grad():\n",
"        for eeg_data, labels, _text, _text_features, _img, img_features in dataloader:\n",
"            # The encoder is built for 250-sample windows.\n",
"            eeg_data = eeg_data.to(device)[:, :, :250]\n",
"            labels = labels.to(device)\n",
"            img_features = img_features.to(device).float()\n",
"            eeg_features = eegmodel(eeg_data).float()\n",
"            features_list.append(eeg_features)\n",
"            logit_scale = eegmodel.logit_scale\n",
"\n",
"            # Reported loss: alpha-weighted MSE + image-contrastive loss (both scaled by 10).\n",
"            regress_loss = mse_loss_fn(eeg_features, img_features)\n",
"            contrastive_loss = eegmodel.loss_func(eeg_features, img_features, logit_scale)\n",
"            loss = alpha * regress_loss * 10 + (1 - alpha) * contrastive_loss * 10\n",
"            total_loss += loss.item()\n",
"            n_batches += 1\n",
"\n",
"            # k-way retrieval: a trial is correct if its true class scores highest among\n",
"            # itself and k-1 randomly drawn other classes.\n",
"            for idx, label in enumerate(labels):\n",
"                true_class = label.item()\n",
"                candidates = random.sample(list(all_labels - {true_class}), k - 1) + [true_class]\n",
"                logits = logit_scale * eeg_features[idx] @ img_features_all[candidates].T\n",
"                if candidates[torch.argmax(logits).item()] == true_class:\n",
"                    correct += 1\n",
"                total += 1\n",
"\n",
"    if n_batches == 0:\n",
"        raise ValueError(\"get_eegfeatures: dataloader yielded no batches\")\n",
"\n",
"    features_tensor = torch.cat(features_list, dim=0).cpu()\n",
"    if save_features:\n",
"        print(\"features_tensor\", features_tensor.shape)\n",
"        os.makedirs(\"emb_eeg\", exist_ok=True)\n",
"        torch.save(features_tensor, f\"emb_eeg/{config['encoder_type']}_eeg_features_{sub}_{mode}.pt\")\n",
"    return total_loss / n_batches, correct / total, labels, features_tensor\n",
"\n",
"from IPython.display import Image, display\n",
"config = {\n",
"    \"data_path\": \"/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz\",\n",
"    \"project\": \"train_pos_img_text_rep\",\n",
"    \"entity\": \"sustech_rethinkingbci\",\n",
"    \"name\": \"lr=3e-4_img_pos_pro_eeg\",\n",
"    \"lr\": 3e-4,\n",
"    \"epochs\": 50,\n",
"    \"batch_size\": 1024,\n",
"    \"logger\": True,\n",
"    \"encoder_type\": 'EEGNetv4_Encoder',\n",
"}\n",
"\n",
"device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"data_path = config['data_path']\n",
"# Precomputed ViT-H-14 image embeddings for the test and train image sets.\n",
"emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')\n",
"emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')\n",
"\n",
"eeg_model = EEGNetv4_Encoder()\n",
"print('number of parameters:', sum([p.numel() for p in eeg_model.parameters()]))\n",
"\n",
"sub = 'sub-08'\n",
"# Pretrained encoder weights for this subject; the path is built from the encoder type and subject.\n",
"checkpoint_path = f\"/home/ldy/Workspace/BrainAligning_retrieval/models/contrast/{config['encoder_type']}/{sub}/02-10_04-58/40.pth\"\n",
"# map_location=\"cpu\" lets the checkpoint load even if the GPU it was saved from is missing;\n",
"# the weights are copied into the model, which is then moved to `device`.\n",
"eeg_model.load_state_dict(torch.load(checkpoint_path, map_location=\"cpu\"))\n",
"eeg_model = eeg_model.to(device)\n",
"\n",
"# Test split: k=200-way retrieval accuracy and the test EEG embeddings (also saved under emb_eeg/).\n",
"test_dataset = EEGDataset(data_path, subjects=[sub], train=False)\n",
"test_loader = DataLoader(test_dataset, batch_size=config[\"batch_size\"], shuffle=False, num_workers=0)\n",
"text_features_test_all = test_dataset.text_features\n",
"img_features_test_all = test_dataset.img_features\n",
"test_loss, test_accuracy, labels, eeg_features_test = get_eegfeatures(sub, eeg_model, test_loader, device, text_features_test_all, img_features_test_all, k=200, mode=\"test\")\n",
"print(f\" - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"self.subjects ['sub-08']\n",
"exclude_subject None\n",
"data_tensor torch.Size([66160, 63, 250])\n",
"Data tensor shape: torch.Size([66160, 63, 250]), label tensor shape: torch.Size([66160]), text length: 1654, image length: 16540\n",
"features_tensor torch.Size([66160, 1024])\n",
" - Test Loss: 7.4921, Test Accuracy: 0.0050\n"
]
}
],
"source": [
"#####################################################################################\n",
"# Training split: score the encoder on the training trials and save the train EEG embeddings.\n",
"train_dataset = EEGDataset(data_path, subjects=[sub], train=True)\n",
"train_loader = DataLoader(train_dataset, batch_size=config[\"batch_size\"], shuffle=False, num_workers=0)\n",
"text_features_train_all = train_dataset.text_features\n",
"# get_eegfeatures indexes img_features_all by class label, but the training split stores one\n",
"# image feature per image (16540 rows for 1654 classes). Passing it directly compared each trial\n",
"# against the first 1654 *images* rather than 1654 classes (consistent with the chance-level\n",
"# 0.0050 recorded for k=200). Average each class's images into one prototype so row i is class i.\n",
"# NOTE(review): assumes img_features are grouped by class in label order with the same number\n",
"# of images per class -- confirm against EEGDataset.\n",
"n_train_classes = text_features_train_all.size(0)\n",
"img_features_per_image = train_dataset.img_features\n",
"assert img_features_per_image.size(0) % n_train_classes == 0, \"expected the same number of images per class\"\n",
"img_features_train_all = img_features_per_image.reshape(n_train_classes, -1, img_features_per_image.size(-1)).mean(dim=1)\n",
"\n",
"train_loss, train_accuracy, labels, eeg_features_train = get_eegfeatures(sub, eeg_model, train_loader, device, text_features_train_all, img_features_train_all, k=200, mode=\"train\")\n",
"print(f\" - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}\")\n",
"#####################################################################################"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"import open_clip\n",
"from matplotlib.font_manager import FontProperties\n",
"\n",
"import sys\n",
"from diffusion_prior import *\n",
"from custom_pipeline import *\n",
"# Re-assert the GPU used by the evaluation cells above (cuda:2) after the wildcard imports.\n",
"device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# The training set has 4 EEG rows per image (66160 = 16540 x 4 in the outputs above), so\n",
"# repeat each image embedding 4 times in place to pair it row-for-row with the EEG embeddings.\n",
"# repeat_interleave is equivalent to view(n_classes, n_images, 1, D).repeat(1, 1, 4, 1).view(-1, D)\n",
"# without hard-coding the class / image counts.\n",
"# NOTE(review): assumes the EEG rows are grouped 4-per-image in image order -- confirm in EEGDataset.\n",
"N_EEG_REPEATS = 4\n",
"emb_img_train_4 = emb_img_train.repeat_interleave(N_EEG_REPEATS, dim=0)\n",
"# Reload the embeddings get_eegfeatures saved (it writes them under the relative emb_eeg/ directory).\n",
"emb_eeg = torch.load(f\"emb_eeg/{config['encoder_type']}_eeg_features_{sub}_train.pt\")\n",
"emb_eeg_test = torch.load(f\"emb_eeg/{config['encoder_type']}_eeg_features_{sub}_test.pt\")\n",
"assert emb_img_train_4.shape == emb_eeg.shape, (emb_img_train_4.shape, emb_eeg.shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([66160, 1024]), torch.Size([200, 1024]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train / test EEG embedding shapes reloaded from disk.\n",
"emb_eeg.shape, emb_eeg_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.6689, -0.0679, -0.0944, ..., -0.4610, -0.4619, 0.6076],\n",
" [-1.2554, 2.1959, 0.2682, ..., 0.1779, -0.9086, -0.6030],\n",
" [-0.0324, 0.0460, -0.0866, ..., -0.4524, 0.2912, -0.2215],\n",
" ...,\n",
" [-0.8543, 0.7007, 1.0704, ..., -0.1372, 0.0445, 1.1926],\n",
" [ 0.0635, -0.1275, 0.4409, ..., -0.2616, -0.4191, 1.0185],\n",
" [ 0.0479, -0.3220, 0.2826, ..., 0.3011, -0.4032, 0.4529]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# In-memory train EEG embeddings returned by get_eegfeatures (used below as the diffusion-prior condition).\n",
"eeg_features_train"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from diffusion_prior import *\n",
"class EmbeddingDataset(Dataset):\n",
"    \"\"\"Paired embeddings for training the diffusion prior.\n",
"\n",
"    Item `idx` is {\"c_embedding\": c_embeddings[idx], \"h_embedding\": h_embeddings[idx]};\n",
"    the length is len(h_embeddings) (0 when h_embeddings is None).\n",
"\n",
"    Args:\n",
"        c_embeddings: conditioning embeddings, paired row-for-row with h_embeddings.\n",
"        h_embeddings: target embeddings.\n",
"        h_embeds_uncond: optional extra embeddings; stored (with their count in\n",
"            `N_uncond`) but not used by __len__ / __getitem__.\n",
"        cond_sampling_rate: stored only; not used by this class.\n",
"\n",
"    Raises:\n",
"        ValueError: if c_embeddings and h_embeddings are both given with different\n",
"            lengths -- the pairs would otherwise be silently misaligned or fail\n",
"            mid-epoch with an IndexError.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, c_embeddings=None, h_embeddings=None, h_embeds_uncond=None, cond_sampling_rate=0.5):\n",
"        if c_embeddings is not None and h_embeddings is not None and len(c_embeddings) != len(h_embeddings):\n",
"            raise ValueError(\n",
"                \"c_embeddings and h_embeddings must have the same length, \"\n",
"                f\"got {len(c_embeddings)} and {len(h_embeddings)}\"\n",
"            )\n",
"        self.c_embeddings = c_embeddings\n",
"        self.h_embeddings = h_embeddings\n",
"        self.N_cond = 0 if self.h_embeddings is None else len(self.h_embeddings)\n",
"        self.h_embeds_uncond = h_embeds_uncond\n",
"        self.N_uncond = 0 if self.h_embeds_uncond is None else len(self.h_embeds_uncond)\n",
"        self.cond_sampling_rate = cond_sampling_rate\n",
"\n",
"    def __len__(self):\n",
"        return self.N_cond\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return {\n",
"            \"c_embedding\": self.c_embeddings[idx],\n",
"            \"h_embedding\": self.h_embeddings[idx]\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9675648\n",
"epoch: 0, loss: 1.1337378648611216\n",
"epoch: 1, loss: 0.9277484618700468\n",
"epoch: 2, loss: 0.7346047126329862\n",
"epoch: 3, loss: 0.5879583936471205\n",
"epoch: 4, loss: 0.48145529077603266\n",
"epoch: 5, loss: 0.4021404353471903\n",
"epoch: 6, loss: 0.34819340018125683\n",
"epoch: 7, loss: 0.3164940769855793\n",
"epoch: 8, loss: 0.29743201365837685\n",
"epoch: 9, loss: 0.2812065954391773\n",
"epoch: 10, loss: 0.2707386599137233\n",
"epoch: 11, loss: 0.25827456024976875\n",
"epoch: 12, loss: 0.24789529328162854\n",
"epoch: 13, loss: 0.23868788549533257\n",
"epoch: 14, loss: 0.2311222133728174\n",
"epoch: 15, loss: 0.22522189250359168\n",
"epoch: 16, loss: 0.21735877761474023\n",
"epoch: 17, loss: 0.21265713274478912\n",
"epoch: 18, loss: 0.21124263864297133\n",
"epoch: 19, loss: 0.20706446170806886\n",
"epoch: 20, loss: 0.20532444371626926\n",
"epoch: 21, loss: 0.2046730555020846\n",
"epoch: 22, loss: 0.20209349187520834\n",
"epoch: 23, loss: 0.20230029294124016\n",
"epoch: 24, loss: 0.19958276244310233\n",
"epoch: 25, loss: 0.1956863997074274\n",
"epoch: 26, loss: 0.19550649454960456\n",
"epoch: 27, loss: 0.1932537264548815\n",
"epoch: 28, loss: 0.1931295592051286\n",
"epoch: 29, loss: 0.19213039852105654\n",
"epoch: 30, loss: 0.19155625036129584\n",
"epoch: 31, loss: 0.18909929784444662\n",
"epoch: 32, loss: 0.18992915061803964\n",
"epoch: 33, loss: 0.188588475722533\n",
"epoch: 34, loss: 0.18476488429766436\n",
"epoch: 35, loss: 0.18678866395583518\n",
"epoch: 36, loss: 0.18525446424117456\n",
"epoch: 37, loss: 0.18486184248557458\n",
"epoch: 38, loss: 0.18248385351437788\n",
"epoch: 39, loss: 0.1811525195837021\n",
"epoch: 40, loss: 0.18144333362579346\n",
"epoch: 41, loss: 0.17867409541056706\n",
"epoch: 42, loss: 0.17740659071848944\n",
"epoch: 43, loss: 0.1764389482828287\n",
"epoch: 44, loss: 0.17700819281431346\n",
"epoch: 45, loss: 0.17582294803399307\n",
"epoch: 46, loss: 0.17429512349458842\n",
"epoch: 47, loss: 0.1736076641541261\n",
"epoch: 48, loss: 0.17225329073575826\n",
"epoch: 49, loss: 0.17144493048007672\n",
"epoch: 50, loss: 0.17105715228961063\n",
"epoch: 51, loss: 0.17152268909491025\n",
"epoch: 52, loss: 0.16880368773753826\n",
"epoch: 53, loss: 0.170586480314915\n",
"epoch: 54, loss: 0.16783449489336746\n",
"epoch: 55, loss: 0.1691400450009566\n",
"epoch: 56, loss: 0.1658798662515787\n",
"epoch: 57, loss: 0.16694954060591183\n",
"epoch: 58, loss: 0.16697401931652656\n",
"epoch: 59, loss: 0.165232030244974\n",
"epoch: 60, loss: 0.16396311521530152\n",
"epoch: 61, loss: 0.16203268628854017\n",
"epoch: 62, loss: 0.16698091740791615\n",
"epoch: 63, loss: 0.16399981792156512\n",
"epoch: 64, loss: 0.16448446810245515\n",
"epoch: 65, loss: 0.16406533626409678\n",
"epoch: 66, loss: 0.16220627793898948\n",
"epoch: 67, loss: 0.1626904058914918\n",
"epoch: 68, loss: 0.16437909328020536\n",
"epoch: 69, loss: 0.1609313971721209\n",
"epoch: 70, loss: 0.1631575719668315\n",
"epoch: 71, loss: 0.1599584792669003\n",
"epoch: 72, loss: 0.1629338324069977\n",
"epoch: 73, loss: 0.16070766242650839\n",
"epoch: 74, loss: 0.15935472937730644\n",
"epoch: 75, loss: 0.16142174509855417\n",
"epoch: 76, loss: 0.15977356502642998\n",
"epoch: 77, loss: 0.15964984893798828\n",
"epoch: 78, loss: 0.15815440118312835\n",
"epoch: 79, loss: 0.1591795611840028\n",
"epoch: 80, loss: 0.1592421217606618\n",
"epoch: 81, loss: 0.15924676289925208\n",
"epoch: 82, loss: 0.15849402982455033\n",
"epoch: 83, loss: 0.1606162871305759\n",
"epoch: 84, loss: 0.15699961208380186\n",
"epoch: 85, loss: 0.15537095528382522\n",
"epoch: 86, loss: 0.1578238546848297\n",
"epoch: 87, loss: 0.15665148473702944\n",
"epoch: 88, loss: 0.15740657655092385\n",
"epoch: 89, loss: 0.1553252440232497\n",
"epoch: 90, loss: 0.15652525012309734\n",
"epoch: 91, loss: 0.15760714686833896\n",
"epoch: 92, loss: 0.15551115549527683\n",
"epoch: 93, loss: 0.15572300988894242\n",
"epoch: 94, loss: 0.15536696429436023\n",
"epoch: 95, loss: 0.1551545427395747\n",
"epoch: 96, loss: 0.15626645042346074\n",
"epoch: 97, loss: 0.15537708309980539\n",
"epoch: 98, loss: 0.15510759628736057\n",
"epoch: 99, loss: 0.1543263880106119\n",
"epoch: 100, loss: 0.15425271919140449\n",
"epoch: 101, loss: 0.15487114764176882\n",
"epoch: 102, loss: 0.15255439717036026\n",
"epoch: 103, loss: 0.1513884425163269\n",
"epoch: 104, loss: 0.1535035007275068\n",
"epoch: 105, loss: 0.15325337877640358\n",
"epoch: 106, loss: 0.1540354270201463\n",
"epoch: 107, loss: 0.154390015739661\n",
"epoch: 108, loss: 0.15147823072396793\n",
"epoch: 109, loss: 0.15220667123794557\n",
"epoch: 110, loss: 0.1523587139753195\n",
"epoch: 111, loss: 0.15091401178103228\n",
"epoch: 112, loss: 0.15136499611230997\n",
"epoch: 113, loss: 0.1518503189086914\n",
"epoch: 114, loss: 0.1512597544835164\n",
"epoch: 115, loss: 0.14968440463909735\n",
"epoch: 116, loss: 0.1488885929951301\n",
"epoch: 117, loss: 0.15059134180729206\n",
"epoch: 118, loss: 0.1497446296306757\n",
"epoch: 119, loss: 0.14996495407361252\n",
"epoch: 120, loss: 0.15109946108781375\n",
"epoch: 121, loss: 0.14965039904300984\n",
"epoch: 122, loss: 0.1514803063410979\n",
"epoch: 123, loss: 0.14924780818132255\n",
"epoch: 124, loss: 0.14947404265403746\n",
"epoch: 125, loss: 0.14996229318472054\n",
"epoch: 126, loss: 0.1485686224240523\n",
"epoch: 127, loss: 0.14805283431823438\n",
"epoch: 128, loss: 0.14889377791147965\n",
"epoch: 129, loss: 0.14838851529818314\n",
"epoch: 130, loss: 0.1487965356845122\n",
"epoch: 131, loss: 0.14777726393479568\n",
"epoch: 132, loss: 0.14756032939140612\n",
"epoch: 133, loss: 0.14845420878667098\n",
"epoch: 134, loss: 0.1478582792557203\n",
"epoch: 135, loss: 0.1495072433581719\n",
"epoch: 136, loss: 0.14733073848944445\n",
"epoch: 137, loss: 0.15026113734795496\n",
"epoch: 138, loss: 0.14988462902032412\n",
"epoch: 139, loss: 0.14790701659826133\n",
"epoch: 140, loss: 0.14759095494563762\n",
"epoch: 141, loss: 0.14727302835537837\n",
"epoch: 142, loss: 0.14784078666797051\n",
"epoch: 143, loss: 0.14688316400234516\n",
"epoch: 144, loss: 0.1472667499230458\n",
"epoch: 145, loss: 0.14967363591377553\n",
"epoch: 146, loss: 0.14848633454396173\n",
"epoch: 147, loss: 0.14954848610437832\n",
"epoch: 148, loss: 0.14868795986358935\n",
"epoch: 149, loss: 0.1491350281697053\n"
]
}
],
"source": [
"# Fit the diffusion prior on paired embeddings: condition c = train EEG features,\n",
"# target h = the image embeddings repeated to one row per EEG row.\n",
"dataset = EmbeddingDataset(c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4)\n",
"dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)\n",
"\n",
"diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
"# Trainable parameter count.\n",
"print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
"pipe = Pipe(diffusion_prior, device=device)\n",
"\n",
"# Tag for this prior, not used in this cell (variants: 'diffusion_prior_vice_pre_imagenet',\n",
"# 'diffusion_prior_vice_pre'). This cell loads no weights itself; the prior is trained by pipe.train.\n",
"model_name = 'diffusion_prior'\n",
"pipe.train(dl, num_epochs=150, learning_rate=1e-3)  # recorded run ends near loss 0.149"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Couldn't connect to the Hub: 504 Server Error: Gateway Time-out for url: https://huggingface.co/api/models/stabilityai/sdxl-turbo.\n",
"Will try to load from local cache.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c1aaf9a052c4471ebf7ec72a126b1bf1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.14it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d0c00f7f085e4f8aba52e8ca7ab73633",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.56it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26734873e7734d16b8f55195b07f7f9c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.17it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a2ac18d49d004101ad78a1db658182d1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.82it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "524c591e92bb4643b7e97ba380027278",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.46it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4eec485b4a1445d6ab1704dd3b5a2179",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 258.96it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0f8666997ac948e99396ee8033b4b030",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 258.36it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dc957e6642234fa680fd376dab669222",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 258.16it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "102d173802ea46b2a3be5413d8d40069",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2644da46940a4aeaaccb796c352bce8c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.86it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1025715eaf5a42b5957ad3c305d26167",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.01it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f99bc32314b6464788a7d0d3fee31b30",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 248.91it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e22d0cc83874c178b332ad97b566c8d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.80it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f3f0755e9efe4ed09662e8b184315960",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.18it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1d29d736249c4ff7bd3f357b33bc4fd2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 239.76it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cd4dbc58c241406993fe13c238e40fda",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.15it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d280f5953cd7479f80b50c40d7a26238",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.26it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a2de32ce6a740a8963b0952cc467005",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.09it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c49ff22ff7c44034aa7c970ce8f7c047",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.06it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3de8c6e20dfe4fa4a0cdc81615b8fac7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.38it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cea8406f617a4337b9ae918671c49578",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2f26a4518c94bab980cbc4688a1821a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.36it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1cb3273977234e44b1a9a10e493d9212",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8846b53013564843b2e049f1f268cf2f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.61it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fea4a7baab4d43f98a701e08b3da2d17",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.01it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "12033088d89c4176bf6c511cc0c0008e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.45it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "78775447c3c64cca9b93f0908d2a6015",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.27it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4ec9e783ca314b428422439ef99c4751",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 123.73it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11bbbdd5ba04416f90f5103cc338c9b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.47it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "668e02034dec4bc89c27675dfa39cead",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.90it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b932fc41ab6d4aa1bda1c990bc240426",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.09it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a1e3fd60cec749f2bb7ff93cf4fc93e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.54it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0abc5e2afb5e4b0ca4ebc3f5f407818f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.09it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39f382f2d47a41e4b31cc680a69b2c39",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 218.48it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ddbdf267a9148849253ae8ad8fc18ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.98it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "612cd7d10e4e4a439e4affd23284af88",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.64it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f7d435a32cea419484979c00d321ddf2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.53it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0efd6862db544f349e75770184c4d945",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.11it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06a0d9582ec049bfa70372f42b0ad5c1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 223.37it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7864e5da375e4e66843e01b716101ca6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.06it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "87bb2c8793d648f0ac6279e9e6d974e0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.74it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "34d4b1a07b0d435bbcce16240681b97e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.42it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a73bf1b37d66432b902957dee67a8bfb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.87it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "640222362bff4c12b30bcf5c683c001a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.99it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "061600bf670840c18f42021e8ef6ae40",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.18it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a2284d0c92454048b38482c28edb7487",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f282e016aba84373ad9839d3207c3e77",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.59it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "04239c2d4daf4d9f86977fd5673a0351",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 226.57it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "41cda47a708d46fa9ed02fdf623ff43f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "13a8faf4f8454f989fdcfde69191ff63",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.57it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3f887620668d415696a2f8dfc672a1ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.04it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "972e007ea00647259ff5dd84eb2955c3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 259.62it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c979262491d1416f82caf6423d88e6b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.70it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b92b630c3ae4720b6aadecd8be1438f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 259.87it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6596ab5dcd5e4d1980b3d42100ed110e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.52it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d9d122edad34913b515515c105752ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.31it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f74d8a458c68495a9f21ad2b5823dfbd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.00it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8af5ddadd5224f18be6cd010bea6026f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 226.85it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7d315a57018d4afab5c4d2d955a88235",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 130.55it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3aab120d839c4fe3ad81242f55b2b64a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.11it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e14bbe70bbd64be7b85a6a82315ed968",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.40it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "508c1a13ed954d068eebea94116b622f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "816d6d467d1744d8aaa5c0313a474cc2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.19it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ca6aa87e560644e58c7e4384e97aa250",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.38it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "384c7075853d4c86a9ae3395c3b61532",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.54it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "064991500ef0444b89a866c394b1e310",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.72it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9471f48bacd546499df04a0a31993045",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.47it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4592853e833148f78a9db0006522c7a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.01it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1c7e123a75a0448097af4e627507940d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.31it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bb37fe1a823d42b8969d1fcb5553244a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.98it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5df90e4358c746a4ac82bb2db1e1ecc4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c9d5e9b765d141d7bb0a60a92c0cc21d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.24it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cd92d5bc794c44cd8058e036e38e2de0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.98it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4d92f9edb7824521b5b7d455a8f0d2dc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.77it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d1b31e3f863f48d285e7ddd6f3608e16",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.53it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "733ee7b0f5264d8cbb1318868898feb0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.86it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2aa6fbe0431c489391fbee2b92af7a33",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.47it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f3d979573a334a55816950a665938146",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.59it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3c2b16d512374c328f4492943489d91e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.97it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "301244c529214c1d9ad3ca77458a7520",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.91it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4fb38fd149fc4c0f944aa79dadaff69f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 214.04it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2438f7af8f064618ad6086d3b2a98a96",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.62it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ffc728de2403431fb06128c0378fa165",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.41it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42fb0f1853d94385a11d4e11777aae08",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.21it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "372ac69b4ed443d3ab0b109c432f338b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.56it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "de1dc7080f804ff8b04fd2794aab888f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.51it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c903dc3379264309a96a0dc545fc604c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.66it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9bcec69fca324dfdbb71f865bc03fb1c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.23it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "da29d96a8330453d9a155aeb30b6f072",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.90it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "293dcbd077d34c28a90165d57b5cf7e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.91it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d3e9e8da58ad4ff38db1da7f8c234f79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 124.99it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "354cd4a14f4f48faabd6ff34fa4c7877",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.79it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "54a1886b87cb4dbdb60c46f3f373e7b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ff6ba27830114e8195ee72a172896a03",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.63it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "389c2c4222cb4957934dfb7ecefafd08",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d28ce92b68094b53ae53dd675456388c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.43it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "04011f4b4a01429a82e20f3537a6fab0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.96it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "24053569fe5b4804aeb2ee6d965e908c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.10it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f69acb99d5444865be1fc5e7103e6575",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.43it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62c11fd9a6994b6797b7a4ef264ca205",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 234.28it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "681c96c3d6fa4f0fa3f221a5897c2144",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.68it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d5f2aeb26634132bda722a134f57706",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.05it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d29488a0421a4a25af4533e2f571a536",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.42it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "718ad7e8592745248a1380393e6c7c08",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.72it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9c58a8d6ef7d430dbc1a66031b0be48a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.30it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0563678c2938436896b3af6ea69de3a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 242.30it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "09a0f78acf254d808cc09f6b31418294",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.39it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bb9a16d5a4844d2c95606be2634d2b6e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.15it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7a059b1d5bb43ec946ee1a69ed5fb83",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.58it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bcb454d0fe7a46ffbd0f7ba5949d49ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9fec2f69964c455096c23d22f5d4e94d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.50it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c7677bdb2d9947cb9dc7036916497b9b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.62it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ddd884ce1a8470681d30866e2997ec8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.87it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3456b0c4e204423590a924ad4710a79a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.67it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "de1c61e65ef244a9bd10071644705d89",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.36it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "870ebbd0b3794fab837ab0a66dd4f370",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f6bac693aef24b09b96b8192a66af81b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 257.44it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "02039cda5b654f268a2339497f27e5d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.24it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8574db4bd3254dbdbfed1a5fdcfce1cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.19it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "731aa7be7d2645ce81947aac8d78c9a0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.96it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "943e28c6fbc24d72bfbcf49ddda096c6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.90it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7874ca03894f43ebb0b054646a1073c9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.93it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "63ef489e12c341f5a5a3f269b013921f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.13it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9aaf7ab4324643f89759428635679194",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 125.50it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6b2baedfd5e643289a85bf7b87f93a1a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.74it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0ad2f10799e4bc4a6278fb4829d32e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.48it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a2d9906ce7d64265acb3d085d53a7bc3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c8ff40dc362d49b8904c8a7cbd4fe327",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.07it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f4fff3a6f7434685aa2168c7901d2c8f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.47it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a145b0bdb29c457bb7def0fbfee36f4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.40it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6afd40c9d894e98bdf09abcbbc65a22",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.78it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28c70ec5368c403c906e6e123c846b38",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.51it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "27871ea93e7d47288129e4de65eec209",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.82it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e9c3af3d77d4e839ddf8abb94e18b42",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.53it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "508fc712453f4da899970def561aedb0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.85it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9bf9c188086947d4bda8e51133052280",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 253.69it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "471bab1e3c074bddbad81aa0fc91b5cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.59it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e79f999ac78d4e03903472b0d95fe7ee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 258.85it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "046d3b3093b84eb88a0626f014edecfd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.13it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38ef3db69e7845f1abc3ec71e58766e6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 259.89it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1f11890f9f2740d593bd481fccfe4df1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.94it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a908088b3520415d9b5bd9840fce05ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 255.92it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6dda43fded55474594056d041f53d95c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.52it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4edeff4eeba641309e323a4ce8ba859a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 220.42it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a0ec3321b6194d20b8bf96930cd90925",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.62it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3794c3c53e1741e5ac1558b7863a7b1b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.13it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb0370b13fbe4be39258c21c6d35e5c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 258.30it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2443cc67e32948439a38224dff7e403e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.07it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c1c1ca8603a9455dbabdaeb7731db2d2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 228.27it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ad776510ca3d4a10b8d5d27009a08ebf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.18it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ec1985d634f341fa80bbba16dbc7c26f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 246.50it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "922ae77306564dc8a19afdfc91ce0549",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.83it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11868bb8f3dc418f907b6969421a818c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.69it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "695cecc50a0a42b5b5046163e36fe6a8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 254.24it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e77bf8978e2445acaa32ce8b3cd1ab96",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.01it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5e40ad0357e435e95980e4ad54d9094",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 216.19it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a737380d37e5455dbfca62dc8991bba0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 118.74it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2667af5b67df4ea6892ba017cf50f644",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.83it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f7cb3cba7d9e4bb7914275d05338fef8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 231.89it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a2d81f798da340838b76b103bce453a8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d87d418e836c4caa979b436c49a39427",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.14it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "191bbdf9b55e4bc0814627362a4ee99f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.21it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d43db3c39174cb5bb9b786a0c556a52",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 171.81it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "408f86d3fa44467195db8e49e3eebea7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "812dcf1595de4bb0a8986f6fe494fce5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 232.56it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "797f33ad8e954f79b399b63f2d22012d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 227.28it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3692501e6a9e4fe08fe7baa9620c4f3a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.84it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d61edd54fbc4f00a9bb72e6c420b3bb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 248.06it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "114145f6297c4de18008358dfcf2d92d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.98it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e8052c27504f405ba51d2bec21a137d8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 241.52it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38f323d606ad4402b934470031a5461c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.52it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47dc0e3c60a1449d9331c06163fa179c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 217.67it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "840fc006ac334a58b6fc0995c4ae46d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 245.98it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5c62c770b4a04f50b015914fc6f52256",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 244.74it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "41afacab3c8c44388fe511c819a81e75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 237.88it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ebcb9680cdc4cb69f7a53fc896bb1b8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.60it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3400aff17a1449e08a7e3f450dd7d718",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 239.31it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "98341b0083b443fca0a6423bf316fcce",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 249.34it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b696a4bda3a44fa98e3d6aad72738a8c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.36it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c1f58bc3f4b474aba16c6b2caa14d0b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 240.47it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f0d03c151a341d1b4df30af3a055ceb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 243.49it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "658821db72c34523b2c84f4f482c6b5c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 202.36it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "41d5d7ef548c4abe886e8858835c6c87",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 233.53it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "96459e52d7c94f5f959e19a422aba39c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 236.42it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "09930788dde24d34b7bb33b7f009938b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 222.06it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "27c8c8135bc340faa08828647d94b712",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.76it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "43b060732b9e4052956329496851a75c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 248.66it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c95feb87b2e487d9dc4dd06d9163a2d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 247.35it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c3ebc636a757473dbb86d1a5b34abce9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 116.90it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "be73a804ce4948aebe485974d7dd0ffc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 230.68it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c7fbe3d42d54fa6ae47d34e12e7afc2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 238.50it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7164115353ea45cabda1a3f5066ece43",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 250.14it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6a62983b2834efea26f22ba179e1f95",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 248.30it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "037446d117af4408bc2567a995950158",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.84it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a95cb244d64046d79ae98a5209152555",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 256.32it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "88bb9569d4e344b1940184b9804668f8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 225.10it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1d52ee5bb1b14add88d3be62732bec13",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.20it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "530a2fc8a9fe49269a4c50a2acbc8798",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 229.40it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c766ba2ae696417e8aa0ff576d9ff8bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 252.02it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "14c3974747ac476db750a70870d798be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"50it [00:00, 251.67it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c95435e3fc24443ebcceed21d84d5214",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"pipe.diffusion_prior.load_state_dict(torch.load(f\"./fintune_ckpts/{config['encoder_type']}/{sub}/{model_name}.pt\", map_location=device))\n",
"# save_path = f'./fintune_ckpts/{config[\"encoder_type\"]}/{sub}/{model_name}.pt'\n",
"\n",
"# directory = os.path.dirname(save_path)\n",
"\n",
"# Create the directory if it doesn't exist\n",
"# os.makedirs(directory, exist_ok=True)\n",
"# torch.save(pipe.diffusion_prior.state_dict(), save_path)\n",
"from PIL import Image\n",
"import os\n",
"\n",
"# Assuming generator.generate returns a PIL Image\n",
"generator = Generator4Embeds(num_inference_steps=4, device=device)\n",
"\n",
"directory = f\"Generation/{config['encoder_type']}/generated_imgs/{sub}\"\n",
"for k in range(200):\n",
" eeg_embeds = emb_eeg_test[k:k+1]\n",
" h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=50, guidance_scale=5.0)\n",
" for j in range(1):\n",
" image = generator.generate(h.to(dtype=torch.float16))\n",
" # Construct the save path for each image\n",
" path = f'{directory}/{texts[k]}/{j}.png'\n",
" # Ensure the directory exists\n",
" os.makedirs(os.path.dirname(path), exist_ok=True)\n",
" # Save the PIL Image\n",
" image.save(path)\n",
" # print(f'Image saved to {path}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import os\n",
"# # os.environ['http_proxy'] = 'http://10.16.35.10:13390' \n",
"# # os.environ['https_proxy'] = 'http://10.16.35.10:13390' \n",
"# # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
"# # os.environ['PATH'] += os.pathsep + '/usr/local/texlive/2023/bin/x86_64-linux'\n",
"\n",
"# import torch\n",
"# from torch import nn\n",
"# import torch.nn.functional as F\n",
"# import torchvision.transforms as transforms\n",
"# import matplotlib.pyplot as plt\n",
"# import open_clip\n",
"# from matplotlib.font_manager import FontProperties\n",
"\n",
"# import sys\n",
"# from diffusion_prior import *\n",
"# from custom_pipeline import *\n",
"\n",
"# device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load EEG and image embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # load image embeddings\n",
"# data = torch.load('/home/ldy/Workspace/THINGS/CLIP/ViT-H-14_features_test.pt', map_location='cuda:3')\n",
"# emb_img_test = data['img_features']\n",
"\n",
"# # load image embeddings\n",
"# data = torch.load('/home/ldy/Workspace/THINGS/CLIP/ViT-H-14_features_train.pt', map_location='cuda:3')\n",
"# emb_img_train = data['img_features']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# emb_img_test = torch.load('variables/ViT-H-14_features_test.pt')\n",
"# emb_img_train = torch.load('variables/ViT-H-14_features_train.pt')\n",
"\n",
"# torch.save(emb_img_test.cpu().detach(), 'variables/ViT-H-14_features_test.pt')\n",
"# torch.save(emb_img_train.cpu().detach(), 'variables/ViT-H-14_features_train.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# emb_img_test.shape, emb_img_train.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1654clsx10imgsx4trials=66160\n",
"# emb_eeg = torch.load('/home/ldy/Workspace/Reconstruction/ATM_S_eeg_features_sub-08.pt')\n",
"\n",
"# emb_eeg_test = torch.load('/home/ldy/Workspace/Reconstruction/ATM_S_eeg_features_sub-08_test.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# emb_eeg.shape, emb_eeg_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training the diffusion prior"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class EmbeddingDataset(Dataset):\n",
"    \"\"\"Paired (condition, target) embeddings for training the diffusion prior.\n",
"\n",
"    Item ``idx`` is ``{\"c_embedding\": c_embeddings[idx], \"h_embedding\": h_embeddings[idx]}``;\n",
"    the dataset length is ``len(h_embeddings)`` (0 when ``h_embeddings`` is None).\n",
"\n",
"    Args:\n",
"        c_embeddings: conditioning embeddings (EEG features in this notebook), paired with\n",
"            ``h_embeddings`` row-for-row.\n",
"        h_embeddings: target embeddings (image features in this notebook).\n",
"        h_embeds_uncond: optional extra target embeddings; stored (count in ``N_uncond``)\n",
"            but not used by ``__getitem__``.\n",
"        cond_sampling_rate: stored but not used by ``__getitem__``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``c_embeddings`` and ``h_embeddings`` are both given with different\n",
"            lengths. Items are paired by index, so a mismatch would otherwise either fail\n",
"            mid-epoch (IndexError) or silently drop the surplus conditioning rows.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, c_embeddings=None, h_embeddings=None, h_embeds_uncond=None, cond_sampling_rate=0.5):\n",
"        if (c_embeddings is not None and h_embeddings is not None\n",
"                and len(c_embeddings) != len(h_embeddings)):\n",
"            raise ValueError(\n",
"                f'c_embeddings ({len(c_embeddings)}) and h_embeddings ({len(h_embeddings)}) '\n",
"                'must have the same length'\n",
"            )\n",
"        self.c_embeddings = c_embeddings\n",
"        self.h_embeddings = h_embeddings\n",
"        self.N_cond = 0 if self.h_embeddings is None else len(self.h_embeddings)\n",
"        self.h_embeds_uncond = h_embeds_uncond\n",
"        self.N_uncond = 0 if self.h_embeds_uncond is None else len(self.h_embeds_uncond)\n",
"        self.cond_sampling_rate = cond_sampling_rate\n",
"\n",
"    def __len__(self):\n",
"        # Length follows the target embeddings.\n",
"        return self.N_cond\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return {\n",
"            \"c_embedding\": self.c_embeddings[idx],\n",
"            \"h_embedding\": self.h_embeddings[idx],\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Each training image was recorded in 4 EEG trials (1654 classes x 10 images x 4 trials = 66160 rows),\n",
"# so repeat every image embedding 4 times consecutively, giving rows ordered (image, trial).\n",
"# NOTE(review): this assumes emb_eeg uses the same row order (trial index fastest) -- confirm.\n",
"# Flattening to (-1, embed_dim) first avoids hard-coding 1654 x 10 x 1024, so other image counts\n",
"# or embedding widths work unchanged.\n",
"TRIALS_PER_IMAGE = 4\n",
"emb_img_train_4 = emb_img_train.reshape(-1, emb_img_train.shape[-1]).repeat_interleave(TRIALS_PER_IMAGE, dim=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shape check: one row per (image, trial) pair.\n",
"emb_img_train_4.size()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# path_data = '/mnt/dataset0/weichen/projects/visobj/proposals/mise/data'\n",
"# image_features = torch.load(os.path.join(path_data, 'openclip_emb/emb_imgnet.pt')) # 'emb_imgnet' or 'image_features'\n",
"# h_embeds_imgnet = image_features['image_features']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pair each EEG embedding with the trial-repeated image embedding at the same row index.\n",
"# (DataLoader and os are imported in the first cell.)\n",
"dataset = EmbeddingDataset(\n",
"    c_embeddings=emb_eeg, h_embeddings=emb_img_train_4,\n",
"    # h_embeds_uncond=h_embeds_imgnet\n",
")\n",
"print(len(dataset))\n",
"# 64 workers was the original setting; cap it at the host's core count so the cell does not\n",
"# oversubscribe smaller machines. os.cpu_count() may return None, hence the fallback to 1.\n",
"NUM_WORKERS = min(64, os.cpu_count() or 1)\n",
"dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=NUM_WORKERS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative (non-UNet) prior: DiffusionPrior(dropout=0.1)\n",
"diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
"# Report the count of trainable parameters.\n",
"print(sum(param.numel() for param in diffusion_prior.parameters() if param.requires_grad))\n",
"pipe = Pipe(diffusion_prior, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load a pretrained prior checkpoint instead of training here.\n",
"# Other checkpoint names: 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
"model_name = 'diffusion_prior'\n",
"ckpt_path = f\"./fintune_ckpts/{config['encoder_type']}/{sub}/{model_name}.pt\"\n",
"pipe.diffusion_prior.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
"# To train instead of loading:\n",
"# pipe.train(dataloader, num_epochs=150, learning_rate=1e-3) # to 0.142 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save model\n",
"# torch.save(pipe.diffusion_prior.state_dict(), f'./ckpts/{model_name}.pt')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generating images directly from embeddings (ground-truth image embeddings as a reference; optionally raw EEG embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Import only `display`: IPython's `Image` would shadow `PIL.Image` from the first cell,\n",
"# and no later cell uses IPython's `Image`.\n",
"from IPython.display import display\n",
"# Decoder that turns image embeddings into images (4 inference steps).\n",
"generator = Generator4Embeds(num_inference_steps=4, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reference: decode the ground-truth image embedding of test sample k directly.\n",
"# Ground-truth images: /home/ldy/Workspace/THINGS/images_set/test_images\n",
"k = 99\n",
"image_embeds = emb_img_test[k:k+1]\n",
"print(f\"image_embeds {image_embeds.shape}\")\n",
"image = generator.generate(image_embeds)\n",
"display(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# image_embeds = emb_eeg_test[k:k+1]\n",
"# print(\"image_embeds\", image_embeds.shape)\n",
"# image = generator.generate(image_embeds)\n",
"# display(image)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generating images from EEG-informed image embeddings (EEG -> diffusion prior -> image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode from EEG: map the EEG embedding of test sample k to an image embedding with the\n",
"# diffusion prior, then decode it. k comes from the ground-truth cell above so the two\n",
"# outputs are directly comparable; uncomment the next line to override it.\n",
"# k = 0\n",
"eeg_embeds = emb_eeg_test[k:k+1]\n",
"print(\"eeg_embeds\", eeg_embeds.shape)\n",
"h = pipe.generate(c_embeds=eeg_embeds, num_inference_steps=50, guidance_scale=5.0)\n",
"# NOTE(review): cast to float16 presumably to match the generator's dtype -- confirm.\n",
"image = generator.generate(h.to(dtype=torch.float16))\n",
"display(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "BCI",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}