{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"from transformers import CLIPVisionModel \n",
"import torch\n",
"from torch import nn\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from transformers import CLIPVisionModel\n",
"from torchvision import transforms\n",
"\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_pixel_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_train.pt')['img_features']# \n",
"test_pixel_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-L-14_features_GIT_test.pt')['img_features']# \n",
"train_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-H-14_features_train.pt')['img_features'].unsqueeze(1)# \n",
"test_img_feature = torch.load('/root/autodl-tmp/Workspace/EEG_caption/ViT-H-14_features_test.pt')['img_features'].unsqueeze(1)# \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_img_feature.shape"
]
},
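{
"cell_type": "markdown",
"metadata": {},
"source": [
"The projector below learns a simple mapping from the single global ViT-H/14 embedding, shape `(B, 1, 1024)` after `unsqueeze(1)`, to the 257-token ViT-L/14 patch-feature grid, shape `(B, 257, 1024)`, used as GIT image features (per the feature file names above). It is trained with a plain MSE objective."
]
},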
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from einops.layers.torch import Rearrange, Reduce\n",
"\n",
"# Define the neural network\n",
"class PixelProjector(nn.Sequential):\n",
" def __init__(self, proj_dim=1024):\n",
" super().__init__(\n",
" Rearrange('B C L->B L C'), \n",
" nn.Linear(1, 257),\n",
" nn.LayerNorm(257),\n",
" Rearrange('B L C->B C L'),\n",
" nn.Linear(1024, 1024),\n",
" nn.LayerNorm(proj_dim),\n",
" )\n",
" \n",
" \n",
"\n",
"# Instantiate the model, loss function, and optimizer\n",
"\n",
"model = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.AdamW(model.parameters(), lr=0.001)\n",
"\n",
"# Prepare data loaders\n",
"train_dataset = TensorDataset(train_img_feature, train_pixel_img_feature)\n",
"test_dataset = TensorDataset(test_img_feature, test_pixel_img_feature)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
"\n",
"# Training loop\n",
"num_epochs = 30\n",
"for epoch in range(num_epochs):\n",
" model.train()\n",
" running_loss = 0.0\n",
" for inputs, targets in train_loader:\n",
" inputs, targets = inputs.to(torch.bfloat16).to(device), targets.to(torch.bfloat16).to(device)\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, targets)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
" \n",
" print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}\")\n",
"\n",
"# Testing loop\n",
"model.eval()\n",
"test_loss = 0.0\n",
"with torch.no_grad():\n",
" for inputs, targets in test_loader:\n",
" inputs, targets = inputs.to(torch.bfloat16).to(device), targets.to(torch.bfloat16).to(device)\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, targets)\n",
" test_loss += loss.item()\n",
"\n",
"print(f\"Test Loss: {test_loss/len(test_loader)}\")\n",
"\n",
"# Save the trained model\n",
"torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n",
"print(\"Model saved as PixelProjector.bin\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model saved as PixelProjector.bin\n"
]
}
],
"source": [
"# Save the trained model\n",
"torch.save(model.state_dict(), '/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin')\n",
"print(\"Model saved as PixelProjector.bin\")"
]
},
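{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of reloading the saved projector for inference, assuming the shapes and paths used above (ViT-H features `(B, 1, 1024)` in, ViT-L-style patch features `(B, 257, 1024)` out). This cell is illustrative and was not part of the original run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inference sketch: reload the saved weights and project the held-out features\n",
"projector = PixelProjector(proj_dim=1024).to(torch.bfloat16).to(device)\n",
"state = torch.load('/root/autodl-tmp/Workspace/EEG_caption/model_weights/PixelProjector_best.bin', map_location='cpu')\n",
"projector.load_state_dict(state)\n",
"projector.eval()\n",
"\n",
"with torch.no_grad():\n",
"    projected = projector(test_img_feature.to(torch.bfloat16).to(device))\n",
"projected.shape  # expected: (N_test, 257, 1024)"
]
},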
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "BCI",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}