{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import AutoProcessor\n",
"from modeling_git import GitForCausalLM, GitModel, GitForCausalLMClipEmb\n",
"from PIL import Image\n",
"import torch\n",
"import requests\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"GitForCausalLMClipEmb(\n",
" (git): GitModelClipEmb(\n",
" (embeddings): GitEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(1024, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (image_encoder): GitVisionModel(\n",
" (vision_model): GitVisionTransformer(\n",
" (embeddings): GitVisionEmbeddings(\n",
" (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
" (position_embedding): Embedding(257, 1024)\n",
" )\n",
" (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (encoder): GitVisionEncoder(\n",
" (layers): ModuleList(\n",
" (0-23): 24 x GitVisionEncoderLayer(\n",
" (self_attn): GitVisionAttention(\n",
" (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
" )\n",
" (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GitVisionMLP(\n",
" (activation_fn): QuickGELUActivation()\n",
" (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
" (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
" )\n",
" (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (encoder): GitEncoder(\n",
" (layer): ModuleList(\n",
" (0-5): 6 x GitLayer(\n",
" (attention): GitAttention(\n",
" (self): GitSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): GitSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): GitIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): GitOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (visual_projection): GitProjection(\n",
" (visual_projection): Sequential(\n",
" (0): Linear(in_features=1024, out_features=768, bias=True)\n",
" (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
" (output): Linear(in_features=768, out_features=30522, bias=True)\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoProcessor, AutoModelForCausalLM\n",
"from modeling_git import GitForCausalLMClipEmb\n",
"processor = AutoProcessor.from_pretrained(\"microsoft/git-large-coco\")\n",
"clip_text_model = GitForCausalLMClipEmb.from_pretrained(\"microsoft/git-large-coco\")\n",
"clip_text_model.to(device) # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4\n",
"clip_text_model.eval().requires_grad_(False)\n",
"\n",
"# url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
"# image = Image.open(requests.get(url, stream=True).raw)\n",
"\n",
"# git_image = Image.open(\"/root/autodl-tmp/Workspace/EEG_caption/docs/test/banana_09s.jpg\")\n",
"# pixel_values = processor(images=git_image, return_tensors=\"pt\").pixel_values.to(device)\n",
"# vision_encoder=model.git.image_encoder\n",
"\n",
"# git_image_features=vision_encoder(pixel_values).last_hidden_state.cpu()\n",
"# git_image_features.shape\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([200, 257, 1024])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"git_test = torch.load(\"/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_test.pt\")['img_features']\n",
"git_test.shape"
]
},
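{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity-check sketch (not part of the recorded run): caption a couple of the\n",
"# precomputed ground-truth image features directly, using the same generate/batch_decode\n",
"# call pattern as the EEG captioning loop further below, to confirm the GIT path works.\n",
"with torch.no_grad():\n",
"    ids = clip_text_model.generate(pixel_values=git_test[:2].to(device).float(), max_new_tokens=25)\n",
"    print(processor.batch_decode(ids, skip_special_tokens=True))\n"
]
},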
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([200, 1, 1024])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb_eeg_test = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ATM_S_eeg_features_sub-08_test.pt').unsqueeze(1)\n",
"emb_eeg_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# image_features.std()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"import open_clip\n",
"from matplotlib.font_manager import FontProperties\n",
"\n",
"import sys\n",
"from diffusion_prior import *\n",
"from custom_pipeline import *\n",
"# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from einops.layers.torch import Rearrange, Reduce\n",
"\n",
"# Define the neural network\n",
"class PixelProjector(nn.Sequential):\n",
" def __init__(self, proj_dim=1024):\n",
" super().__init__(\n",
" Rearrange('B C L->B L C'), \n",
" nn.Linear(1, 257),\n",
" nn.LayerNorm(257),\n",
" Rearrange('B L C->B C L'),\n",
" nn.Linear(1024, 1024),\n",
" nn.LayerNorm(proj_dim),\n",
" )\n",
"model = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
"model.load_state_dict(torch.load('/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin'))\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9675648\n"
]
},
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
"# number of parameters\n",
"print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
"pipe = Pipe(diffusion_prior, device=device)\n",
"\n",
"# load pretrained model\n",
"model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
"# pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142 \n",
"pipe.diffusion_prior.load_state_dict(torch.load(f'{model_name}.pt', map_location=device))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"4it [00:00, 11.67it/s]\n",
"4it [00:00, 353.27it/s]\n",
"4it [00:00, 365.84it/s]\n",
"4it [00:00, 374.32it/s]\n",
"4it [00:00, 364.18it/s]\n",
"4it [00:00, 369.63it/s]\n",
"4it [00:00, 366.92it/s]\n",
"4it [00:00, 366.76it/s]\n",
"4it [00:00, 361.30it/s]\n",
"4it [00:00, 367.80it/s]\n",
"4it [00:00, 364.50it/s]\n",
"4it [00:00, 345.36it/s]\n",
"4it [00:00, 360.85it/s]\n",
"4it [00:00, 359.27it/s]\n",
"4it [00:00, 367.69it/s]\n",
"4it [00:00, 359.10it/s]\n",
"4it [00:00, 360.75it/s]\n",
"4it [00:00, 363.47it/s]\n",
"4it [00:00, 361.98it/s]\n",
"4it [00:00, 357.24it/s]\n",
"4it [00:00, 363.58it/s]\n",
"4it [00:00, 361.63it/s]\n",
"4it [00:00, 359.38it/s]\n",
"4it [00:00, 360.46it/s]\n",
"4it [00:00, 355.67it/s]\n",
"4it [00:00, 356.56it/s]\n",
"4it [00:00, 358.51it/s]\n",
"4it [00:00, 360.47it/s]\n",
"4it [00:00, 359.57it/s]\n",
"4it [00:00, 358.81it/s]\n",
"4it [00:00, 360.69it/s]\n",
"4it [00:00, 360.58it/s]\n",
"4it [00:00, 357.33it/s]\n",
"4it [00:00, 358.23it/s]\n",
"4it [00:00, 359.55it/s]\n",
"4it [00:00, 360.18it/s]\n",
"4it [00:00, 352.28it/s]\n",
"4it [00:00, 356.47it/s]\n",
"4it [00:00, 289.61it/s]\n",
"4it [00:00, 344.86it/s]\n",
"4it [00:00, 248.81it/s]\n",
"4it [00:00, 249.06it/s]\n",
"4it [00:00, 336.66it/s]\n",
"4it [00:00, 353.12it/s]\n",
"4it [00:00, 350.25it/s]\n",
"4it [00:00, 351.55it/s]\n",
"4it [00:00, 360.83it/s]\n",
"4it [00:00, 355.59it/s]\n",
"4it [00:00, 355.04it/s]\n",
"4it [00:00, 354.31it/s]\n",
"4it [00:00, 350.29it/s]\n",
"4it [00:00, 356.26it/s]\n",
"4it [00:00, 353.47it/s]\n",
"4it [00:00, 355.24it/s]\n",
"4it [00:00, 352.88it/s]\n",
"4it [00:00, 352.49it/s]\n",
"4it [00:00, 356.39it/s]\n",
"4it [00:00, 359.04it/s]\n",
"4it [00:00, 355.63it/s]\n",
"4it [00:00, 357.08it/s]\n",
"4it [00:00, 349.72it/s]\n",
"4it [00:00, 355.16it/s]\n",
"4it [00:00, 353.97it/s]\n",
"4it [00:00, 285.44it/s]\n",
"4it [00:00, 356.64it/s]\n",
"4it [00:00, 351.78it/s]\n",
"4it [00:00, 357.85it/s]\n",
"4it [00:00, 355.42it/s]\n",
"4it [00:00, 352.28it/s]\n",
"4it [00:00, 355.04it/s]\n",
"4it [00:00, 350.69it/s]\n",
"4it [00:00, 349.34it/s]\n",
"4it [00:00, 351.92it/s]\n",
"4it [00:00, 352.23it/s]\n",
"4it [00:00, 349.14it/s]\n",
"4it [00:00, 357.88it/s]\n",
"4it [00:00, 354.36it/s]\n",
"4it [00:00, 356.30it/s]\n",
"4it [00:00, 353.04it/s]\n",
"4it [00:00, 350.11it/s]\n",
"4it [00:00, 353.29it/s]\n",
"4it [00:00, 355.66it/s]\n",
"4it [00:00, 352.93it/s]\n",
"4it [00:00, 352.06it/s]\n",
"4it [00:00, 353.79it/s]\n",
"4it [00:00, 350.02it/s]\n",
"4it [00:00, 353.12it/s]\n",
"4it [00:00, 352.60it/s]\n",
"4it [00:00, 350.71it/s]\n",
"4it [00:00, 351.19it/s]\n",
"4it [00:00, 349.69it/s]\n",
"4it [00:00, 354.92it/s]\n",
"4it [00:00, 353.15it/s]\n",
"4it [00:00, 352.40it/s]\n",
"4it [00:00, 353.79it/s]\n",
"4it [00:00, 353.99it/s]\n",
"4it [00:00, 346.54it/s]\n",
"4it [00:00, 353.18it/s]\n",
"4it [00:00, 351.55it/s]\n",
"4it [00:00, 359.30it/s]\n",
"4it [00:00, 351.62it/s]\n",
"4it [00:00, 354.21it/s]\n",
"4it [00:00, 354.26it/s]\n",
"4it [00:00, 353.37it/s]\n",
"4it [00:00, 353.48it/s]\n",
"4it [00:00, 353.62it/s]\n",
"4it [00:00, 352.71it/s]\n",
"4it [00:00, 352.70it/s]\n",
"4it [00:00, 348.50it/s]\n",
"4it [00:00, 321.42it/s]\n",
"4it [00:00, 345.93it/s]\n",
"4it [00:00, 348.08it/s]\n",
"4it [00:00, 351.94it/s]\n",
"4it [00:00, 351.76it/s]\n",
"4it [00:00, 352.82it/s]\n",
"4it [00:00, 357.37it/s]\n",
"4it [00:00, 352.74it/s]\n",
"4it [00:00, 354.33it/s]\n",
"4it [00:00, 350.48it/s]\n",
"4it [00:00, 348.21it/s]\n",
"4it [00:00, 348.38it/s]\n",
"4it [00:00, 354.03it/s]\n",
"4it [00:00, 344.58it/s]\n",
"4it [00:00, 346.82it/s]\n",
"4it [00:00, 350.94it/s]\n",
"4it [00:00, 352.11it/s]\n",
"4it [00:00, 347.05it/s]\n",
"4it [00:00, 341.68it/s]\n",
"4it [00:00, 342.75it/s]\n",
"4it [00:00, 348.68it/s]\n",
"4it [00:00, 343.13it/s]\n",
"4it [00:00, 343.06it/s]\n",
"4it [00:00, 343.35it/s]\n",
"4it [00:00, 347.40it/s]\n",
"4it [00:00, 350.15it/s]\n",
"4it [00:00, 351.54it/s]\n",
"4it [00:00, 350.18it/s]\n",
"4it [00:00, 350.75it/s]\n",
"4it [00:00, 347.51it/s]\n",
"4it [00:00, 348.60it/s]\n",
"4it [00:00, 349.85it/s]\n",
"4it [00:00, 349.02it/s]\n",
"4it [00:00, 346.72it/s]\n",
"4it [00:00, 348.99it/s]\n",
"4it [00:00, 347.12it/s]\n",
"4it [00:00, 347.33it/s]\n",
"4it [00:00, 343.94it/s]\n",
"4it [00:00, 351.80it/s]\n",
"4it [00:00, 339.51it/s]\n",
"4it [00:00, 345.77it/s]\n",
"4it [00:00, 320.71it/s]\n",
"4it [00:00, 347.12it/s]\n",
"4it [00:00, 299.13it/s]\n",
"4it [00:00, 342.32it/s]\n",
"4it [00:00, 341.15it/s]\n",
"4it [00:00, 343.63it/s]\n",
"4it [00:00, 332.51it/s]\n",
"4it [00:00, 334.89it/s]\n",
"4it [00:00, 332.36it/s]\n",
"4it [00:00, 319.50it/s]\n",
"4it [00:00, 334.13it/s]\n",
"4it [00:00, 346.07it/s]\n",
"4it [00:00, 347.97it/s]\n",
"4it [00:00, 354.15it/s]\n",
"4it [00:00, 349.47it/s]\n",
"4it [00:00, 353.73it/s]\n",
"4it [00:00, 349.39it/s]\n",
"4it [00:00, 350.40it/s]\n",
"4it [00:00, 343.61it/s]\n",
"4it [00:00, 348.56it/s]\n",
"4it [00:00, 343.45it/s]\n",
"4it [00:00, 327.54it/s]\n",
"4it [00:00, 338.26it/s]\n",
"4it [00:00, 345.96it/s]\n",
"4it [00:00, 349.45it/s]\n",
"4it [00:00, 350.58it/s]\n",
"4it [00:00, 349.25it/s]\n",
"4it [00:00, 350.17it/s]\n",
"4it [00:00, 348.97it/s]\n",
"4it [00:00, 348.10it/s]\n",
"4it [00:00, 344.23it/s]\n",
"4it [00:00, 316.63it/s]\n",
"4it [00:00, 331.28it/s]\n",
"4it [00:00, 338.91it/s]\n",
"4it [00:00, 340.05it/s]\n",
"4it [00:00, 329.27it/s]\n",
"4it [00:00, 323.06it/s]\n",
"4it [00:00, 327.69it/s]\n",
"4it [00:00, 324.32it/s]\n",
"4it [00:00, 341.94it/s]\n",
"4it [00:00, 346.87it/s]\n",
"4it [00:00, 350.94it/s]\n",
"4it [00:00, 348.59it/s]\n",
"4it [00:00, 349.61it/s]\n",
"4it [00:00, 351.08it/s]\n",
"4it [00:00, 345.71it/s]\n",
"4it [00:00, 349.37it/s]\n",
"4it [00:00, 342.29it/s]\n",
"4it [00:00, 350.42it/s]\n",
"4it [00:00, 350.13it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Captions generated and saved to 'generated_captions.txt'.\n"
]
}
],
"source": [
"# Loop through each element in emb_eeg_test and generate captions\n",
"with open('semantic_level_caption.txt', 'w') as f:\n",
" for emb in emb_eeg_test:\n",
" # Generate h for each emb\n",
" h = pipe.generate(c_embeds=emb, num_inference_steps=4, guidance_scale=5.0)\n",
"\n",
" # Get test image\n",
" test_img_257_1024 = model(h.unsqueeze(0).to(torch.bfloat16).to(device))\n",
"\n",
" # Generate captions\n",
" generated_ids = clip_text_model.generate(pixel_values=test_img_257_1024.to(device).float(), max_new_tokens=25)\n",
" generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
" print(generated_caption)\n",
" # Write each caption to a new line in the txt file\n",
" f.write(f\"{generated_caption[0]}\\n\")\n",
"\n",
"print(\"Captions generated and saved to 'generated_captions.txt'.\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# dataset = EmbeddingDataset(\n",
"# c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4, \n",
"# # h_embeds_uncond=h_embeds_imgnet\n",
"# )\n",
"# dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)\n",
"diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
"# number of parameters\n",
"print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
"pipe = Pipe(diffusion_prior, device=device)\n",
"\n",
"# load pretrained model\n",
"model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
"# pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142 \n",
"pipe.diffusion_prior.load_state_dict(torch.load(f'{model_name}.pt', map_location=device))\n",
"\n",
"h = pipe.generate(c_embeds=emb_eeg_test[0], num_inference_steps=4, guidance_scale=5.0)\n",
"\n",
"\n",
"model = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
"model.load_state_dict(torch.load('/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin'))\n",
"\n",
"\n",
"test_img_257_1024 = model(h.unsqueeze(0).to(torch.bfloat16).to(device))\n",
"\n",
"generated_ids = clip_text_model.generate(pixel_values=test_img_257_1024.to(device).float(), max_new_tokens=25)\n",
"\n",
"generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
"generated_caption"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
"# generated_caption"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "BCI",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.undefined"
}
},
"nbformat": 4,
"nbformat_minor": 2
}