
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoProcessor\n",
    "from modeling_git import GitForCausalLM, GitModel, GitForCausalLMClipEmb\n",
    "from PIL import Image\n",
    "import torch\n",
    "import requests\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GitForCausalLMClipEmb(\n",
       "  (git): GitModelClipEmb(\n",
       "    (embeddings): GitEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(1024, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (image_encoder): GitVisionModel(\n",
       "      (vision_model): GitVisionTransformer(\n",
       "        (embeddings): GitVisionEmbeddings(\n",
       "          (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
       "          (position_embedding): Embedding(257, 1024)\n",
       "        )\n",
       "        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "        (encoder): GitVisionEncoder(\n",
       "          (layers): ModuleList(\n",
       "            (0-23): 24 x GitVisionEncoderLayer(\n",
       "              (self_attn): GitVisionAttention(\n",
       "                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              )\n",
       "              (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "              (mlp): GitVisionMLP(\n",
       "                (activation_fn): QuickGELUActivation()\n",
       "                (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "                (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "              )\n",
       "              (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (encoder): GitEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-5): 6 x GitLayer(\n",
       "          (attention): GitAttention(\n",
       "            (self): GitSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): GitSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): GitIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): GitOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (visual_projection): GitProjection(\n",
       "      (visual_projection): Sequential(\n",
       "        (0): Linear(in_features=1024, out_features=768, bias=True)\n",
       "        (1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (output): Linear(in_features=768, out_features=30522, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoProcessor, AutoModelForCausalLM\n",
    "from modeling_git import GitForCausalLMClipEmb\n",
    "processor = AutoProcessor.from_pretrained(\"microsoft/git-large-coco\")\n",
    "clip_text_model = GitForCausalLMClipEmb.from_pretrained(\"microsoft/git-large-coco\")\n",
    "clip_text_model.to(device) # if you get OOM running this script, you can switch this to cpu and lower minibatch_size to 4\n",
    "clip_text_model.eval().requires_grad_(False)\n",
    "\n",
    "# url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "# image = Image.open(requests.get(url, stream=True).raw)\n",
    "\n",
    "# git_image = Image.open(\"/root/autodl-tmp/Workspace/EEG_caption/docs/test/banana_09s.jpg\")\n",
    "# pixel_values = processor(images=git_image, return_tensors=\"pt\").pixel_values.to(device)\n",
    "# vision_encoder=model.git.image_encoder\n",
    "\n",
    "# git_image_features=vision_encoder(pixel_values).last_hidden_state.cpu()\n",
    "# git_image_features.shape\n"
   ]
  },
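  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (a sketch re-enabling the commented-out code above; it assumes network access to the COCO sample URL): caption a real image to confirm the GIT decoding path works before feeding it projected EEG features.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: encode a real image with GIT's vision tower, then decode a caption.\n",
    "# The ClipEmb variant takes precomputed vision features as `pixel_values`.\n",
    "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "pixel_values = processor(images=image, return_tensors=\"pt\").pixel_values.to(device)\n",
    "image_features = clip_text_model.git.image_encoder(pixel_values).last_hidden_state  # (1, 257, 1024)\n",
    "ids = clip_text_model.generate(pixel_values=image_features, max_new_tokens=25)\n",
    "print(processor.batch_decode(ids, skip_special_tokens=True))\n"
   ]
  },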
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([200, 257, 1024])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "git_test = torch.load(\"/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_test.pt\")['img_features']\n",
    "git_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([200, 1, 1024])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb_eeg_test = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ATM_S_eeg_features_sub-08_test.pt').unsqueeze(1)\n",
    "emb_eeg_test.shape"
   ]
  },
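  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The test set holds one 1024-d EEG embedding per stimulus (ATM_S encoder, subject 08); `unsqueeze(1)` adds a token axis so each sample enters the diffusion prior as `(1, 1024)`.\n"
   ]
  },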
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image_features.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import open_clip\n",
    "from matplotlib.font_manager import FontProperties\n",
    "\n",
    "import sys\n",
    "from diffusion_prior import *\n",
    "from custom_pipeline import *\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5\" \n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from einops.layers.torch import Rearrange, Reduce\n",
    "\n",
    "# Define the neural network\n",
    "class PixelProjector(nn.Sequential):\n",
    "    def __init__(self, proj_dim=1024):\n",
    "        super().__init__(\n",
    "            Rearrange('B C L->B L C'),    \n",
    "            nn.Linear(1, 257),\n",
    "            nn.LayerNorm(257),\n",
    "            Rearrange('B L C->B C L'),\n",
    "            nn.Linear(1024, 1024),\n",
    "            nn.LayerNorm(proj_dim),\n",
    "            )\n",
    "model = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
    "model.load_state_dict(torch.load('/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin'))\n"
   ]
  },
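  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check (a sketch on random input, not part of the original run): the projector should map `(B, 1, 1024)` to `(B, 257, 1024)`, the same layout as the GIT vision features loaded above (1 CLS token + 16x16 patches).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape sanity check with random input; expected output: torch.Size([2, 257, 1024])\n",
    "with torch.no_grad():\n",
    "    dummy = torch.randn(2, 1, 1024, dtype=torch.bfloat16, device=device)\n",
    "    print(model(dummy).shape)\n"
   ]
  },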
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9675648\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
    "# number of parameters\n",
    "print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
    "pipe = Pipe(diffusion_prior, device=device)\n",
    "\n",
    "# load pretrained model\n",
    "model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
    "# pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142 \n",
    "pipe.diffusion_prior.load_state_dict(torch.load(f'{model_name}.pt', map_location=device))"
   ]
  },
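  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the pieces together: each test EEG embedding conditions the diffusion prior, which samples an image embedding `h`; `PixelProjector` expands `h` into a `(1, 257, 1024)` pseudo vision-feature grid; and `GitForCausalLMClipEmb.generate` decodes that grid into a caption. The loop below runs this per sample and writes one caption per line to `semantic_level_caption.txt`.\n"
   ]
  },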
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4it [00:00, 11.67it/s]\n",
      "4it [00:00, 353.27it/s]\n",
      "4it [00:00, 365.84it/s]\n",
      "4it [00:00, 374.32it/s]\n",
      "4it [00:00, 364.18it/s]\n",
      "4it [00:00, 369.63it/s]\n",
      "4it [00:00, 366.92it/s]\n",
      "4it [00:00, 366.76it/s]\n",
      "4it [00:00, 361.30it/s]\n",
      "4it [00:00, 367.80it/s]\n",
      "4it [00:00, 364.50it/s]\n",
      "4it [00:00, 345.36it/s]\n",
      "4it [00:00, 360.85it/s]\n",
      "4it [00:00, 359.27it/s]\n",
      "4it [00:00, 367.69it/s]\n",
      "4it [00:00, 359.10it/s]\n",
      "4it [00:00, 360.75it/s]\n",
      "4it [00:00, 363.47it/s]\n",
      "4it [00:00, 361.98it/s]\n",
      "4it [00:00, 357.24it/s]\n",
      "4it [00:00, 363.58it/s]\n",
      "4it [00:00, 361.63it/s]\n",
      "4it [00:00, 359.38it/s]\n",
      "4it [00:00, 360.46it/s]\n",
      "4it [00:00, 355.67it/s]\n",
      "4it [00:00, 356.56it/s]\n",
      "4it [00:00, 358.51it/s]\n",
      "4it [00:00, 360.47it/s]\n",
      "4it [00:00, 359.57it/s]\n",
      "4it [00:00, 358.81it/s]\n",
      "4it [00:00, 360.69it/s]\n",
      "4it [00:00, 360.58it/s]\n",
      "4it [00:00, 357.33it/s]\n",
      "4it [00:00, 358.23it/s]\n",
      "4it [00:00, 359.55it/s]\n",
      "4it [00:00, 360.18it/s]\n",
      "4it [00:00, 352.28it/s]\n",
      "4it [00:00, 356.47it/s]\n",
      "4it [00:00, 289.61it/s]\n",
      "4it [00:00, 344.86it/s]\n",
      "4it [00:00, 248.81it/s]\n",
      "4it [00:00, 249.06it/s]\n",
      "4it [00:00, 336.66it/s]\n",
      "4it [00:00, 353.12it/s]\n",
      "4it [00:00, 350.25it/s]\n",
      "4it [00:00, 351.55it/s]\n",
      "4it [00:00, 360.83it/s]\n",
      "4it [00:00, 355.59it/s]\n",
      "4it [00:00, 355.04it/s]\n",
      "4it [00:00, 354.31it/s]\n",
      "4it [00:00, 350.29it/s]\n",
      "4it [00:00, 356.26it/s]\n",
      "4it [00:00, 353.47it/s]\n",
      "4it [00:00, 355.24it/s]\n",
      "4it [00:00, 352.88it/s]\n",
      "4it [00:00, 352.49it/s]\n",
      "4it [00:00, 356.39it/s]\n",
      "4it [00:00, 359.04it/s]\n",
      "4it [00:00, 355.63it/s]\n",
      "4it [00:00, 357.08it/s]\n",
      "4it [00:00, 349.72it/s]\n",
      "4it [00:00, 355.16it/s]\n",
      "4it [00:00, 353.97it/s]\n",
      "4it [00:00, 285.44it/s]\n",
      "4it [00:00, 356.64it/s]\n",
      "4it [00:00, 351.78it/s]\n",
      "4it [00:00, 357.85it/s]\n",
      "4it [00:00, 355.42it/s]\n",
      "4it [00:00, 352.28it/s]\n",
      "4it [00:00, 355.04it/s]\n",
      "4it [00:00, 350.69it/s]\n",
      "4it [00:00, 349.34it/s]\n",
      "4it [00:00, 351.92it/s]\n",
      "4it [00:00, 352.23it/s]\n",
      "4it [00:00, 349.14it/s]\n",
      "4it [00:00, 357.88it/s]\n",
      "4it [00:00, 354.36it/s]\n",
      "4it [00:00, 356.30it/s]\n",
      "4it [00:00, 353.04it/s]\n",
      "4it [00:00, 350.11it/s]\n",
      "4it [00:00, 353.29it/s]\n",
      "4it [00:00, 355.66it/s]\n",
      "4it [00:00, 352.93it/s]\n",
      "4it [00:00, 352.06it/s]\n",
      "4it [00:00, 353.79it/s]\n",
      "4it [00:00, 350.02it/s]\n",
      "4it [00:00, 353.12it/s]\n",
      "4it [00:00, 352.60it/s]\n",
      "4it [00:00, 350.71it/s]\n",
      "4it [00:00, 351.19it/s]\n",
      "4it [00:00, 349.69it/s]\n",
      "4it [00:00, 354.92it/s]\n",
      "4it [00:00, 353.15it/s]\n",
      "4it [00:00, 352.40it/s]\n",
      "4it [00:00, 353.79it/s]\n",
      "4it [00:00, 353.99it/s]\n",
      "4it [00:00, 346.54it/s]\n",
      "4it [00:00, 353.18it/s]\n",
      "4it [00:00, 351.55it/s]\n",
      "4it [00:00, 359.30it/s]\n",
      "4it [00:00, 351.62it/s]\n",
      "4it [00:00, 354.21it/s]\n",
      "4it [00:00, 354.26it/s]\n",
      "4it [00:00, 353.37it/s]\n",
      "4it [00:00, 353.48it/s]\n",
      "4it [00:00, 353.62it/s]\n",
      "4it [00:00, 352.71it/s]\n",
      "4it [00:00, 352.70it/s]\n",
      "4it [00:00, 348.50it/s]\n",
      "4it [00:00, 321.42it/s]\n",
      "4it [00:00, 345.93it/s]\n",
      "4it [00:00, 348.08it/s]\n",
      "4it [00:00, 351.94it/s]\n",
      "4it [00:00, 351.76it/s]\n",
      "4it [00:00, 352.82it/s]\n",
      "4it [00:00, 357.37it/s]\n",
      "4it [00:00, 352.74it/s]\n",
      "4it [00:00, 354.33it/s]\n",
      "4it [00:00, 350.48it/s]\n",
      "4it [00:00, 348.21it/s]\n",
      "4it [00:00, 348.38it/s]\n",
      "4it [00:00, 354.03it/s]\n",
      "4it [00:00, 344.58it/s]\n",
      "4it [00:00, 346.82it/s]\n",
      "4it [00:00, 350.94it/s]\n",
      "4it [00:00, 352.11it/s]\n",
      "4it [00:00, 347.05it/s]\n",
      "4it [00:00, 341.68it/s]\n",
      "4it [00:00, 342.75it/s]\n",
      "4it [00:00, 348.68it/s]\n",
      "4it [00:00, 343.13it/s]\n",
      "4it [00:00, 343.06it/s]\n",
      "4it [00:00, 343.35it/s]\n",
      "4it [00:00, 347.40it/s]\n",
      "4it [00:00, 350.15it/s]\n",
      "4it [00:00, 351.54it/s]\n",
      "4it [00:00, 350.18it/s]\n",
      "4it [00:00, 350.75it/s]\n",
      "4it [00:00, 347.51it/s]\n",
      "4it [00:00, 348.60it/s]\n",
      "4it [00:00, 349.85it/s]\n",
      "4it [00:00, 349.02it/s]\n",
      "4it [00:00, 346.72it/s]\n",
      "4it [00:00, 348.99it/s]\n",
      "4it [00:00, 347.12it/s]\n",
      "4it [00:00, 347.33it/s]\n",
      "4it [00:00, 343.94it/s]\n",
      "4it [00:00, 351.80it/s]\n",
      "4it [00:00, 339.51it/s]\n",
      "4it [00:00, 345.77it/s]\n",
      "4it [00:00, 320.71it/s]\n",
      "4it [00:00, 347.12it/s]\n",
      "4it [00:00, 299.13it/s]\n",
      "4it [00:00, 342.32it/s]\n",
      "4it [00:00, 341.15it/s]\n",
      "4it [00:00, 343.63it/s]\n",
      "4it [00:00, 332.51it/s]\n",
      "4it [00:00, 334.89it/s]\n",
      "4it [00:00, 332.36it/s]\n",
      "4it [00:00, 319.50it/s]\n",
      "4it [00:00, 334.13it/s]\n",
      "4it [00:00, 346.07it/s]\n",
      "4it [00:00, 347.97it/s]\n",
      "4it [00:00, 354.15it/s]\n",
      "4it [00:00, 349.47it/s]\n",
      "4it [00:00, 353.73it/s]\n",
      "4it [00:00, 349.39it/s]\n",
      "4it [00:00, 350.40it/s]\n",
      "4it [00:00, 343.61it/s]\n",
      "4it [00:00, 348.56it/s]\n",
      "4it [00:00, 343.45it/s]\n",
      "4it [00:00, 327.54it/s]\n",
      "4it [00:00, 338.26it/s]\n",
      "4it [00:00, 345.96it/s]\n",
      "4it [00:00, 349.45it/s]\n",
      "4it [00:00, 350.58it/s]\n",
      "4it [00:00, 349.25it/s]\n",
      "4it [00:00, 350.17it/s]\n",
      "4it [00:00, 348.97it/s]\n",
      "4it [00:00, 348.10it/s]\n",
      "4it [00:00, 344.23it/s]\n",
      "4it [00:00, 316.63it/s]\n",
      "4it [00:00, 331.28it/s]\n",
      "4it [00:00, 338.91it/s]\n",
      "4it [00:00, 340.05it/s]\n",
      "4it [00:00, 329.27it/s]\n",
      "4it [00:00, 323.06it/s]\n",
      "4it [00:00, 327.69it/s]\n",
      "4it [00:00, 324.32it/s]\n",
      "4it [00:00, 341.94it/s]\n",
      "4it [00:00, 346.87it/s]\n",
      "4it [00:00, 350.94it/s]\n",
      "4it [00:00, 348.59it/s]\n",
      "4it [00:00, 349.61it/s]\n",
      "4it [00:00, 351.08it/s]\n",
      "4it [00:00, 345.71it/s]\n",
      "4it [00:00, 349.37it/s]\n",
      "4it [00:00, 342.29it/s]\n",
      "4it [00:00, 350.42it/s]\n",
      "4it [00:00, 350.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Captions generated and saved to 'generated_captions.txt'.\n"
     ]
    }
   ],
   "source": [
    "# Loop through each element in emb_eeg_test and generate captions\n",
    "with open('semantic_level_caption.txt', 'w') as f:\n",
    "    for emb in emb_eeg_test:\n",
    "        # Generate h for each emb\n",
    "        h = pipe.generate(c_embeds=emb, num_inference_steps=4, guidance_scale=5.0)\n",
    "\n",
    "        # Get test image\n",
    "        test_img_257_1024 = model(h.unsqueeze(0).to(torch.bfloat16).to(device))\n",
    "\n",
    "        # Generate captions\n",
    "        generated_ids = clip_text_model.generate(pixel_values=test_img_257_1024.to(device).float(), max_new_tokens=25)\n",
    "        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "        print(generated_caption)\n",
    "        # Write each caption to a new line in the txt file\n",
    "        f.write(f\"{generated_caption[0]}\\n\")\n",
    "\n",
    "print(\"Captions generated and saved to 'generated_captions.txt'.\")\n"
   ]
  },
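  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above runs one embedding at a time. A minibatched sketch follows; it assumes `pipe.generate` accepts batched conditioning embeddings of shape `(B, 1024)`, which may not hold for every `Pipe` implementation, so treat it as a starting point rather than a drop-in replacement.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical batched variant -- assumes pipe.generate supports batched c_embeds.\n",
    "# Reduce batch_size (e.g. to 4) if you hit OOM.\n",
    "batch_size = 8\n",
    "captions = []\n",
    "for i in range(0, len(emb_eeg_test), batch_size):\n",
    "    c = emb_eeg_test[i:i + batch_size].squeeze(1)  # (B, 1024)\n",
    "    h = pipe.generate(c_embeds=c, num_inference_steps=4, guidance_scale=5.0)\n",
    "    feats = model(h.unsqueeze(1).to(torch.bfloat16).to(device))  # (B, 257, 1024)\n",
    "    ids = clip_text_model.generate(pixel_values=feats.float(), max_new_tokens=25)\n",
    "    captions.extend(processor.batch_decode(ids, skip_special_tokens=True))\n"
   ]
  },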
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = EmbeddingDataset(\n",
    "#     c_embeddings=eeg_features_train, h_embeddings=emb_img_train_4, \n",
    "#     # h_embeds_uncond=h_embeds_imgnet\n",
    "# )\n",
    "# dl = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=64)\n",
    "diffusion_prior = DiffusionPriorUNet(cond_dim=1024, dropout=0.1)\n",
    "# number of parameters\n",
    "print(sum(p.numel() for p in diffusion_prior.parameters() if p.requires_grad))\n",
    "pipe = Pipe(diffusion_prior, device=device)\n",
    "\n",
    "# load pretrained model\n",
    "model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'\n",
    "# pipe.train(dl, num_epochs=150, learning_rate=1e-3) # to 0.142 \n",
    "pipe.diffusion_prior.load_state_dict(torch.load(f'{model_name}.pt', map_location=device))\n",
    "\n",
    "h = pipe.generate(c_embeds=emb_eeg_test[0], num_inference_steps=4, guidance_scale=5.0)\n",
    "\n",
    "\n",
    "model = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
    "model.load_state_dict(torch.load('/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin'))\n",
    "\n",
    "\n",
    "test_img_257_1024 = model(h.unsqueeze(0).to(torch.bfloat16).to(device))\n",
    "\n",
    "generated_ids = clip_text_model.generate(pixel_values=test_img_257_1024.to(device).float(), max_new_tokens=25)\n",
    "\n",
    "generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "generated_caption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "# generated_caption"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "BCI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.1.undefined"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}