--- /dev/null
+++ b/alignment_model.py
@@ -0,0 +1,72 @@
+# alignment_model.py
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer
+from typing import List, Tuple, Optional
+
+class ImageTextAlignmentModel(nn.Module):
+    def __init__(self, image_embedding_dim: int = 512, text_embedding_dim: Optional[int] = None):
+        super().__init__()
+
+        # Initialize BioGPT encoder and tokenizer
+        self.text_encoder = AutoModel.from_pretrained('microsoft/biogpt')
+        self.tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
+
+        if text_embedding_dim is None:
+            text_embedding_dim = self.text_encoder.config.hidden_size
+
+        # Projection networks with layer normalization
+        self.image_projection = nn.Sequential(
+            nn.Linear(image_embedding_dim, text_embedding_dim),
+            nn.LayerNorm(text_embedding_dim),
+            nn.GELU(),
+        )
+
+        self.text_projection = nn.Sequential(
+            nn.Linear(text_embedding_dim, text_embedding_dim),
+            nn.LayerNorm(text_embedding_dim),
+            nn.GELU(),
+        )
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize weights with Xavier uniform distribution"""
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+    def encode_text(self, text: List[str], device: torch.device) -> torch.Tensor:
+        """Encode text using BioGPT"""
+        # Tokenize and encode text
+        text_encoding = self.tokenizer(
+            text,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            max_length=512
+        ).to(device)
+
+        # Get text features
+        with torch.no_grad():
+            text_outputs = self.text_encoder(**text_encoding)
+            text_features = text_outputs.last_hidden_state[:, 0, :]  # Take [CLS] token
+
+        return text_features
+
+    def forward(self, image_embeddings: torch.Tensor, text: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Get device
+        device = image_embeddings.device
+
+        # Encode text
+        text_features = self.encode_text(text, device)
+
+        # Project features
+        projected_image = self.image_projection(image_embeddings)
+        projected_text = self.text_projection(text_features)
+
+        return projected_image, projected_text