"""Leave-one-subject-out EEG/image retrieval dataset (THINGS-EEG style layout).

Pairs preprocessed EEG epochs with CLIP text/image embeddings so a retrieval
model can be trained either subject-dependently (``subjects=[...]``) or
subject-independently (``exclude_subject='sub-XX'`` leaves one subject out of
training and evaluates on it).

Side effects at import time: sets an HTTP proxy, loads an Open-CLIP ViT-H-14
model onto the default device, and reads ``data_config.json`` for data paths.
"""
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import clip  # used only for its tokenizer in Textencoder
from torch.nn import functional as F
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
import json
import open_clip

# NOTE(review): hard-coded local proxy applied at import time — make this
# configurable (or remove) before running outside the original machine.
proxy = 'http://127.0.0.1:7890'
os.environ['http_proxy'] = proxy
os.environ['https_proxy'] = proxy

cuda_device_count = torch.cuda.device_count()
print(cuda_device_count)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Open-CLIP backbone used to embed stimulus captions and images.
model_type = 'ViT-H-14'
vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms(
    model_type, pretrained='laion2b_s32b_b79k', precision='fp32', device=device)

# Load the configuration (EEG root plus train/test image folders) from JSON.
config_path = "data_config.json"
with open(config_path, "r") as config_file:
    config = json.load(config_file)

data_path = config["data_path"]
img_directory_training = config["img_directory_training"]
img_directory_test = config["img_directory_test"]


class EEGDataset(Dataset):
    """EEG/caption/image retrieval dataset.

    subjects = ['sub-01', 'sub-02', 'sub-05', 'sub-04', 'sub-03', 'sub-06',
                'sub-07', 'sub-08', 'sub-09', 'sub-10']

    Parameters
    ----------
    data_path : str
        Root folder containing one sub-directory per subject.
    exclude_subject : str | None
        Subject held out of training (and used alone for testing).
    subjects : list[str] | None
        Subjects to load; defaults to every directory under ``data_path``.
    train : bool
        Select the 1654-class training split or the 200-class test split.
    time_window : sequence[float]
        ``[start, end]`` in seconds used to crop the EEG epochs.
    classes, pictures : list[int] | None
        Optional restriction to specific class / picture indices.
    """

    def __init__(self, data_path, exclude_subject=None, subjects=None, train=True,
                 time_window=(0, 1.0), classes=None, pictures=None, val_size=None):
        # NOTE: default changed from a mutable list to an equivalent tuple.
        self.data_path = data_path
        self.train = train
        self.subject_list = os.listdir(data_path)
        self.subjects = self.subject_list if subjects is None else subjects
        self.n_sub = len(self.subjects)
        self.time_window = time_window
        self.n_cls = 1654 if train else 200
        self.classes = classes
        self.pictures = pictures
        self.exclude_subject = exclude_subject
        self.val_size = val_size
        # At least one requested subject must actually exist on disk.
        assert any(sub in self.subject_list for sub in self.subjects)

        self.data, self.labels, self.text, self.img = self.load_data()
        self.data = self.extract_eeg(self.data, time_window)

        if self.classes is None and self.pictures is None:
            # Full split: CLIP features are expensive, so cache them on disk.
            features_filename = (f'{model_type}_features_train.pt' if self.train
                                 else f'{model_type}_features_test.pt')
            if os.path.exists(features_filename):
                saved_features = torch.load(features_filename)
                self.text_features = saved_features['text_features']
                self.img_features = saved_features['img_features']
            else:
                self.text_features = self.Textencoder(self.text)
                self.img_features = self.ImageEncoder(self.img)
                torch.save({
                    'text_features': self.text_features.cpu(),
                    'img_features': self.img_features.cpu(),
                }, features_filename)
        else:
            # Restricted subsets are cheap and not worth caching.
            self.text_features = self.Textencoder(self.text)
            self.img_features = self.ImageEncoder(self.img)

    def load_data(self):
        """Load EEG tensors, integer labels, captions and image paths.

        Returns ``(data_tensor, label_tensor, texts, images)`` where the EEG
        tensor is flattened across subjects/classes/repetitions.
        """
        data_list = []
        label_list = []
        texts = []
        images = []

        directory = img_directory_training if self.train else img_directory_test

        # One sub-directory per concept, named like "<nnnnn>_<description>".
        dirnames = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
        dirnames.sort()

        if self.classes is not None:
            dirnames = [dirnames[i] for i in self.classes]

        for dirname in dirnames:
            try:
                idx = dirname.index('_')
                description = dirname[idx + 1:]
            except ValueError:
                print(f"Skipped: {dirname} due to no '_' found.")
                continue
            # Caption fed to the CLIP text encoder.
            new_description = f"This picture is {description}"
            texts.append(new_description)

        img_directory = img_directory_training if self.train else img_directory_test

        all_folders = [d for d in os.listdir(img_directory) if os.path.isdir(os.path.join(img_directory, d))]
        all_folders.sort()

        if self.classes is not None and self.pictures is not None:
            # One specific picture per requested class.
            images = []
            for i in range(len(self.classes)):
                class_idx = self.classes[i]
                pic_idx = self.pictures[i]
                if class_idx < len(all_folders):
                    folder = all_folders[class_idx]
                    folder_path = os.path.join(img_directory, folder)
                    all_images = [img for img in os.listdir(folder_path)
                                  if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
                    all_images.sort()
                    if pic_idx < len(all_images):
                        images.append(os.path.join(folder_path, all_images[pic_idx]))
        elif self.classes is not None and self.pictures is None:
            # Every picture of the requested classes.
            images = []
            for i in range(len(self.classes)):
                class_idx = self.classes[i]
                if class_idx < len(all_folders):
                    folder = all_folders[class_idx]
                    folder_path = os.path.join(img_directory, folder)
                    all_images = [img for img in os.listdir(folder_path)
                                  if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
                    all_images.sort()
                    images.extend(os.path.join(folder_path, img) for img in all_images)
        elif self.classes is None:
            # Every picture of every class.
            images = []
            for folder in all_folders:
                folder_path = os.path.join(img_directory, folder)
                all_images = [img for img in os.listdir(folder_path)
                              if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
                all_images.sort()
                images.extend(os.path.join(folder_path, img) for img in all_images)
        else:
            # NOTE(review): unreachable — the three branches above are exhaustive.
            print("Error")

        print("self.subjects", self.subjects)
        print("exclude_subject", self.exclude_subject)
        for subject in self.subjects:
            if self.train:
                if subject == self.exclude_subject:
                    continue
                file_name = 'preprocessed_eeg_training.npy'
                file_path = os.path.join(self.data_path, subject, file_name)
                data = np.load(file_path, allow_pickle=True)

                preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
                # Drop the first 50 time samples (pre-stimulus baseline,
                # presumably — TODO confirm against the preprocessing script).
                times = torch.from_numpy(data['times']).detach()[50:]
                ch_names = data['ch_names']

                n_classes = 1654
                samples_per_class = 10

                if self.classes is not None and self.pictures is not None:
                    for c, p in zip(self.classes, self.pictures):
                        start_index = c * 1 + p
                        if start_index < len(preprocessed_eeg_data):
                            preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index + 1]
                            labels = torch.full((1,), c, dtype=torch.long).detach()
                            data_list.append(preprocessed_eeg_data_class)
                            label_list.append(labels)
                elif self.classes is not None and self.pictures is None:
                    for c in self.classes:
                        start_index = c * samples_per_class
                        preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index + samples_per_class]
                        labels = torch.full((samples_per_class,), c, dtype=torch.long).detach()
                        data_list.append(preprocessed_eeg_data_class)
                        label_list.append(labels)
                else:
                    for i in range(n_classes):
                        start_index = i * samples_per_class
                        preprocessed_eeg_data_class = preprocessed_eeg_data[start_index: start_index + samples_per_class]
                        labels = torch.full((samples_per_class,), i, dtype=torch.long).detach()
                        data_list.append(preprocessed_eeg_data_class)
                        label_list.append(labels)
            else:
                # Test split: only the excluded subject (or all, when no
                # subject is excluded) contributes data.
                if subject == self.exclude_subject or self.exclude_subject is None:
                    file_name = 'preprocessed_eeg_test.npy'
                    file_path = os.path.join(self.data_path, subject, file_name)
                    data = np.load(file_path, allow_pickle=True)
                    preprocessed_eeg_data = torch.from_numpy(data['preprocessed_eeg_data']).float().detach()
                    times = torch.from_numpy(data['times']).detach()[50:]
                    ch_names = data['ch_names']
                    n_classes = 200  # Each class contains 1 image
                    samples_per_class = 1

                    for i in range(n_classes):
                        # Skip classes outside the requested subset.
                        if self.classes is not None and i not in self.classes:
                            continue
                        start_index = i * samples_per_class
                        preprocessed_eeg_data_class = preprocessed_eeg_data[start_index:start_index + samples_per_class]
                        labels = torch.full((samples_per_class,), i, dtype=torch.long).detach()
                        # Average over the repetition axis for test epochs.
                        preprocessed_eeg_data_class = torch.mean(preprocessed_eeg_data_class.squeeze(0), 0)
                        data_list.append(preprocessed_eeg_data_class)
                        label_list.append(labels)
                else:
                    continue

        # data_list: (subjects * classes) tensors of (10, 4, channels, time);
        # flatten to (subjects * classes * 10 * 4, channels, time) for train.
        if self.train:
            data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape[2:])
            print("data_tensor", data_tensor.shape)
        else:
            data_tensor = torch.cat(data_list, dim=0).view(-1, *data_list[0].shape)

        label_tensor = torch.cat(label_list, dim=0)
        if self.train:
            # One label per repetition: (subjects * classes * 10 * 4,)
            label_tensor = label_tensor.repeat_interleave(4)
            if self.classes is not None:
                # Remap the selected class ids to 0..len(classes)-1,
                # preserving first-occurrence order.
                unique_values = torch.tensor(list(dict.fromkeys(label_tensor.tolist())))
                mapping = {val.item(): index for index, val in enumerate(unique_values)}
                label_tensor = torch.tensor([mapping[val.item()] for val in label_tensor], dtype=torch.long)
        else:
            pass

        # NOTE(review): if no subject matched above, `times`/`ch_names` are
        # unbound here and this raises NameError — guarded only by the assert
        # in __init__.
        self.times = times
        self.ch_names = ch_names

        print(f"Data tensor shape: {data_tensor.shape}, label tensor shape: {label_tensor.shape}, text length: {len(texts)}, image length: {len(images)}")

        return data_tensor, label_tensor, texts, images

    def extract_eeg(self, eeg_data, time_window):
        """Crop the last (time) axis of ``eeg_data`` to ``time_window`` seconds."""
        start, end = time_window
        # Boolean mask over self.times selecting the requested window.
        indices = (self.times >= start) & (self.times <= end)
        extracted_data = eeg_data[..., indices]
        return extracted_data

    def Textencoder(self, text):
        """Encode captions with the CLIP text tower; returns L2-normalized features."""
        # NOTE(review): clip.tokenize is the OpenAI tokenizer while the model
        # comes from open_clip — confirm it matches ViT-H-14's tokenizer
        # (open_clip.get_tokenizer(model_type) would be the safe choice).
        text_inputs = torch.cat([clip.tokenize(t) for t in text]).to(device)
        with torch.no_grad():
            text_features = vlmodel.encode_text(text_inputs)
        text_features = F.normalize(text_features, dim=-1).detach()
        return text_features

    def ImageEncoder(self, images):
        """Encode image files in batches of 20; returns L2-normalized features."""
        batch_size = 20
        image_features_list = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            image_inputs = torch.stack(
                [preprocess_train(Image.open(img).convert("RGB")) for img in batch_images]).to(device)
            with torch.no_grad():
                batch_image_features = vlmodel.encode_image(image_inputs)
                batch_image_features /= batch_image_features.norm(dim=-1, keepdim=True)
            image_features_list.append(batch_image_features)
        image_features = torch.cat(image_features_list, dim=0)
        return image_features

    def __getitem__(self, index):
        """Return ``(eeg, label, caption, caption_feat, img_path, img_feat)``.

        Train indices enumerate (subjects * classes * 10 pictures * 4 reps);
        the caption/image indices are recovered by integer division.
        """
        x = self.data[index]
        label = self.labels[index]

        if self.pictures is None:
            if self.classes is None:
                index_n_sub_train = self.n_cls * 10 * 4
                index_n_sub_test = self.n_cls * 1 * 80
            else:
                index_n_sub_test = len(self.classes) * 1 * 80
                index_n_sub_train = len(self.classes) * 10 * 4
            # One caption per class.
            if self.train:
                text_index = (index % index_n_sub_train) // (10 * 4)
            else:
                text_index = (index % index_n_sub_test)
            # One image per (class, picture) pair.
            if self.train:
                img_index = (index % index_n_sub_train) // (4)
            else:
                img_index = (index % index_n_sub_test)
        else:
            if self.classes is None:
                index_n_sub_train = self.n_cls * 1 * 4
                index_n_sub_test = self.n_cls * 1 * 80
            else:
                index_n_sub_test = len(self.classes) * 1 * 80
                index_n_sub_train = len(self.classes) * 1 * 4
            if self.train:
                text_index = (index % index_n_sub_train) // (1 * 4)
            else:
                text_index = (index % index_n_sub_test)
            if self.train:
                img_index = (index % index_n_sub_train) // (4)
            else:
                img_index = (index % index_n_sub_test)

        text = self.text[text_index]
        img = self.img[img_index]
        text_features = self.text_features[text_index]
        img_features = self.img_features[img_index]

        return x, label, text, text_features, img, img_features

    def __len__(self):
        # Same as self.labels.shape[0].
        return self.data.shape[0]


if __name__ == "__main__":
    # Smoke test: load one subject's train/test splits and fetch a sample.
    train_dataset = EEGDataset(data_path, subjects=['sub-01'], train=True)
    test_dataset = EEGDataset(data_path, subjects=['sub-01'], train=False)

    # 100 Hz
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    i = 80 * 1 - 1
    x, label, text, text_features, img, img_features = test_dataset[i]
    print(f"Index {i}, Label: {label}, text: {text}")
    Image.open(img)