--- a +++ b/alignment_model.py @@ -0,0 +1,72 @@ +# alignment_model.py + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoTokenizer +from typing import List, Tuple, Optional + +class ImageTextAlignmentModel(nn.Module): + def __init__(self, image_embedding_dim: int = 512, text_embedding_dim: Optional[int] = None): + super().__init__() + + # Initialize BioGPT encoder and tokenizer + self.text_encoder = AutoModel.from_pretrained('microsoft/biogpt') + self.tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt') + + if text_embedding_dim is None: + text_embedding_dim = self.text_encoder.config.hidden_size + + # Projection networks with layer normalization + self.image_projection = nn.Sequential( + nn.Linear(image_embedding_dim, text_embedding_dim), + nn.LayerNorm(text_embedding_dim), + nn.GELU(), + ) + + self.text_projection = nn.Sequential( + nn.Linear(text_embedding_dim, text_embedding_dim), + nn.LayerNorm(text_embedding_dim), + nn.GELU(), + ) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + """Initialize weights with Xavier uniform distribution""" + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def encode_text(self, text: List[str], device: torch.device) -> torch.Tensor: + """Encode text using BioGPT""" + # Tokenize and encode text + text_encoding = self.tokenizer( + text, + padding=True, + truncation=True, + return_tensors="pt", + max_length=512 + ).to(device) + + # Get text features + with torch.no_grad(): + text_outputs = self.text_encoder(**text_encoding) + text_features = text_outputs.last_hidden_state[:, 0, :] # Take [CLS] token + + return text_features + + def forward(self, image_embeddings: torch.Tensor, text: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + # Get device + device = image_embeddings.device + + # Encode text + text_features = self.encode_text(text, device) + + # Project features + projected_image = self.image_projection(image_embeddings) + projected_text = self.text_projection(text_features) + + return projected_image, projected_text