Switch to side-by-side view

--- a
+++ b/Generation/image_adapter.ipynb
@@ -0,0 +1,167 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from functools import partial\n",
+    "\n",
+    "from transformers import CLIPVisionModel \n",
+    "import torch\n",
+    "from torch import nn\n",
+    "from torchvision import transforms\n",
+    "from PIL import Image\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from transformers import CLIPVisionModel\n",
+    "from torchvision import transforms\n",
+    "\n",
+    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Directory that holds the precomputed image-feature files.\n",
+    "FEATURE_DIR = '/root/autodl-tmp/Workspace/EEG_caption'\n",
+    "\n",
+    "\n",
+    "def load_img_features(filename):\n",
+    "    \"\"\"Return the 'img_features' entry of the checkpoint FEATURE_DIR/filename.\n",
+    "\n",
+    "    map_location='cpu' makes loading independent of the device the file was\n",
+    "    saved from, so the notebook also works on CPU-only machines. Batches are\n",
+    "    moved to `device` by the training cell.\n",
+    "\n",
+    "    Args:\n",
+    "        filename: File name (relative to FEATURE_DIR) of a torch-saved dict.\n",
+    "\n",
+    "    Returns:\n",
+    "        The object stored under the 'img_features' key.\n",
+    "    \"\"\"\n",
+    "    # NOTE(review): torch.load unpickles the file; only load files you trust.\n",
+    "    return torch.load(f'{FEATURE_DIR}/{filename}', map_location='cpu')['img_features']\n",
+    "\n",
+    "\n",
+    "# Regression targets of the projector (second tensor in the TensorDatasets of the training cell).\n",
+    "# presumably shaped (N, 257, 1024), matching the projector output - TODO confirm\n",
+    "train_pixel_img_feature = load_img_features('ViT-L-14_features_GIT_train.pt')\n",
+    "test_pixel_img_feature = load_img_features('ViT-L-14_features_GIT_test.pt')\n",
+    "\n",
+    "# Projector inputs; unsqueeze(1) inserts a length-1 axis at dim 1.\n",
+    "train_img_feature = load_img_features('ViT-H-14_features_train.pt').unsqueeze(1)\n",
+    "test_img_feature = load_img_features('ViT-H-14_features_test.pt').unsqueeze(1)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_img_feature.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "from torch.utils.data import DataLoader, TensorDataset\n",
+    "from einops.layers.torch import Rearrange, Reduce\n",
+    "\n",
+    "# Define the neural network\n",
+    "class PixelProjector(nn.Sequential):\n",
+    "    \"\"\"Expand a single image embedding into a sequence of token embeddings.\n",
+    "\n",
+    "    Shape flow for an input of shape (B, 1, in_dim):\n",
+    "        (B, 1, in_dim)\n",
+    "        -> (B, in_dim, 1)            swap the last two axes\n",
+    "        -> (B, in_dim, num_tokens)   Linear + LayerNorm over the token axis\n",
+    "        -> (B, num_tokens, in_dim)   swap the last two axes back\n",
+    "        -> (B, num_tokens, proj_dim) Linear + LayerNorm over the feature axis\n",
+    "\n",
+    "    The layer order is fixed, so state-dict keys do not depend on the arguments.\n",
+    "\n",
+    "    Args:\n",
+    "        proj_dim: Feature size of every output token.\n",
+    "        in_dim: Feature size of the input embedding.\n",
+    "        num_tokens: Number of tokens produced per input embedding.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, proj_dim=1024, in_dim=1024, num_tokens=257):\n",
+    "        super().__init__(\n",
+    "            Rearrange('B C L->B L C'),\n",
+    "            # The middle input axis must have length 1 for this Linear to apply.\n",
+    "            nn.Linear(1, num_tokens),\n",
+    "            nn.LayerNorm(num_tokens),\n",
+    "            Rearrange('B L C->B C L'),\n",
+    "            # Output width follows proj_dim so the LayerNorm below always matches it.\n",
+    "            nn.Linear(in_dim, proj_dim),\n",
+    "            nn.LayerNorm(proj_dim),\n",
+    "        )\n",
+    "        \n",
+    "        \n",
+    "\n",
+    "# Instantiate the model, loss function, and optimizer\n",
+    "\n",
+    "# Seed PyTorch's RNG, which drives weight initialisation and DataLoader shuffling.\n",
+    "SEED = 42\n",
+    "torch.manual_seed(SEED)\n",
+    "\n",
+    "model = PixelProjector(proj_dim=1024).to(device=device, dtype=torch.bfloat16)\n",
+    "criterion = nn.MSELoss()\n",
+    "optimizer = optim.AdamW(model.parameters(), lr=0.001)\n",
+    "\n",
+    "# Prepare data loaders: each sample is an (input embedding, target token features) pair.\n",
+    "train_dataset = TensorDataset(train_img_feature, train_pixel_img_feature)\n",
+    "test_dataset = TensorDataset(test_img_feature, test_pixel_img_feature)\n",
+    "\n",
+    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)\n",
+    "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
+    "\n",
+    "# Training loop\n",
+    "num_epochs = 30\n",
+    "for epoch in range(num_epochs):\n",
+    "    model.train()\n",
+    "    running_loss = 0.0\n",
+    "    for inputs, targets in train_loader:\n",
+    "        # Match the model's dtype and device in a single conversion.\n",
+    "        inputs = inputs.to(device=device, dtype=torch.bfloat16)\n",
+    "        targets = targets.to(device=device, dtype=torch.bfloat16)\n",
+    "        optimizer.zero_grad()\n",
+    "        outputs = model(inputs)\n",
+    "        loss = criterion(outputs, targets)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        running_loss += loss.item()\n",
+    "\n",
+    "    # drop_last=True keeps every training batch the same size, so the mean of\n",
+    "    # the batch means equals the per-sample mean.\n",
+    "    print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}\")\n",
+    "\n",
+    "# Testing loop\n",
+    "model.eval()\n",
+    "test_loss = 0.0\n",
+    "with torch.no_grad():\n",
+    "    for inputs, targets in test_loader:\n",
+    "        inputs = inputs.to(device=device, dtype=torch.bfloat16)\n",
+    "        targets = targets.to(device=device, dtype=torch.bfloat16)\n",
+    "        outputs = model(inputs)\n",
+    "        loss = criterion(outputs, targets)\n",
+    "        # The last test batch may be smaller than the others, so weight each\n",
+    "        # batch mean by its size instead of averaging the batch means.\n",
+    "        test_loss += loss.item() * inputs.size(0)\n",
+    "\n",
+    "print(f\"Test Loss: {test_loss/len(test_dataset)}\")\n",
+    "\n",
+    "# Save the trained model\n",
+    "# NOTE(review): despite the '_best' suffix these are the final-epoch weights;\n",
+    "# no checkpoint selection happens in this cell.\n",
+    "checkpoint_path = '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin'\n",
+    "torch.save(model.state_dict(), checkpoint_path)\n",
+    "print(f\"Model saved to {checkpoint_path}\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Model saved as PixelProjector.bin\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Save the trained model\n",
+    "torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n",
+    "print(\"Model saved as PixelProjector.bin\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "BCI",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}