--- /dev/null
+++ b/src/data/data_pipeline.py
@@ -0,0 +1,167 @@
+import numpy as np
+from typing import Tuple, Optional, Generator
+from sklearn.model_selection import StratifiedKFold
+from sklearn.utils import shuffle
+import joblib
+from pathlib import Path
+from src.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class DataPipeline:
+    """
+    Data pipeline for model training with advanced data handling capabilities.
+
+    Features:
+    - Stratified sampling
+    - Cross-validation splits
+    - Data caching
+    - Batch generation
+    """
+
+    def __init__(self,
+                 data_dir: str,
+                 batch_size: int = 32,
+                 n_splits: int = 5,
+                 cache_dir: Optional[str] = None,
+                 random_state: int = 42):
+        # Initialize data pipeline
+        self.data_dir = Path(data_dir)
+        self.batch_size = batch_size
+        self.n_splits = n_splits
+        self.cache_dir = Path(cache_dir) if cache_dir else None
+        self.random_state = random_state
+        self.logger = get_logger(self.__class__.__name__)
+
+        # Load prepared data
+        self.load_data()
+
+        # Initialize cross-validation splitter
+        self.cv = StratifiedKFold(
+            n_splits=self.n_splits,
+            shuffle=True,
+            random_state=self.random_state
+        )
+
+    def load_data(self):
+        """Load prepared data from files."""
+        try:
+            # Load features and labels for each split
+            self.data = {}
+            for split in ['train', 'val', 'test']:
+                features_path = self.data_dir / f'{split}_features.npy'
+                labels_path = self.data_dir / f'{split}_labels.npy'
+                features = np.load(features_path)
+                labels = np.load(labels_path)
+                self.data[split] = (features, labels)
+
+            # Load metadata
+            metadata_path = self.data_dir / 'metadata.joblib'
+            self.metadata = joblib.load(metadata_path)
+
+            self.logger.info("Data loaded successfully")
+            self._log_data_info()
+
+        except Exception as e:
+            self.logger.error(f"Error loading data: {str(e)}")
+            raise
+
+    def _log_data_info(self):
+        """Log information about loaded data."""
+        for split, (features, labels) in self.data.items():
+            self.logger.info(f"\n{split.capitalize()} set:")
+            self.logger.info(f"Features shape: {features.shape}")
+            self.logger.info(f"Labels shape: {labels.shape}")
+            self.logger.info(f"Classes: {np.unique(labels)}")
+
+    def get_cv_splits(self) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
+        """Generate cross-validation splits over the training set."""
+        features, labels = self.data['train']
+        for fold, (train_idx, val_idx) in enumerate(self.cv.split(features, labels)):
+            self.logger.info(f"Generating split for fold {fold + 1}/{self.n_splits}")
+            yield train_idx, val_idx
+
+    def get_batch_generator(self,
+                            split: str,
+                            batch_size: Optional[int] = None,
+                            shuffle_data: bool = True) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
+        """Generate batches of data for the given split."""
+        features, labels = self.data[split]
+        batch_size = batch_size if batch_size is not None else self.batch_size
+
+        # Create cache key if caching is enabled
+        if self.cache_dir:
+            cache_key = f"{split}_batch_{batch_size}"
+            if not self._check_cache(cache_key):
+                # Save the original, unshuffled data to cache
+                self._save_to_cache(cache_key, features, labels)
+            else:
+                features, labels = self._load_from_cache(cache_key)
+
+        # Shuffle if requested (the cache preserves the original order)
+        if shuffle_data:
+            features, labels = shuffle(features, labels, random_state=self.random_state)
+
+        num_samples = len(features)
+        num_batches = (num_samples + batch_size - 1) // batch_size
+
+        for i in range(num_batches):
+            start_idx = i * batch_size
+            end_idx = min(start_idx + batch_size, num_samples)
+            yield features[start_idx:end_idx], labels[start_idx:end_idx]
+
+    def get_all_data(self, split: str) -> Tuple[np.ndarray, np.ndarray]:
+        """Get all data for a split."""
+        return self.data[split]
+
+    def _check_cache(self, key: str) -> bool:
+        """Check if data exists in cache."""
+        if not self.cache_dir:
+            return False
+
+        features_path = self.cache_dir / f"{key}_features.npy"
+        labels_path = self.cache_dir / f"{key}_labels.npy"
+        return features_path.exists() and labels_path.exists()
+
+    def _load_from_cache(self, key: str) -> Tuple[np.ndarray, np.ndarray]:
+        """Load data from cache."""
+        features = np.load(self.cache_dir / f"{key}_features.npy")
+        labels = np.load(self.cache_dir / f"{key}_labels.npy")
+        self.logger.info(f"Loaded cache with key: {key}")
+        return features, labels
+
+    def _save_to_cache(self, key: str, features: np.ndarray, labels: np.ndarray):
+        """Save data to cache."""
+        if self.cache_dir:
+            self.cache_dir.mkdir(parents=True, exist_ok=True)
+            np.save(self.cache_dir / f"{key}_features.npy", features)
+            np.save(self.cache_dir / f"{key}_labels.npy", labels)
+            self.logger.info(f"Saved cache with key: {key}")
+
+
+# Example usage and testing
+if __name__ == "__main__":
+    # Initialize pipeline - adjust paths as needed
+    pipeline = DataPipeline(
+        data_dir='../../data/prepared_data',
+        batch_size=32,
+        n_splits=5,
+        cache_dir='../../data/cache'
+    )
+
+    # Test cross-validation splits
+    logger.info("\nTesting cross-validation splits:")
+    for fold, (train_idx, val_idx) in enumerate(pipeline.get_cv_splits()):
+        logger.info(f"Fold {fold + 1}:")
+        logger.info(f"Training samples: {len(train_idx)}")
+        logger.info(f"Validation samples: {len(val_idx)}")
+
+    # Test batch generator
+    logger.info("\nTesting batch generator:")
+    for split in ['train', 'val', 'test']:
+        batch_gen = pipeline.get_batch_generator(split)
+        num_batches = 0
+        for features_batch, labels_batch in batch_gen:
+            num_batches += 1
+        logger.info(f"{split.capitalize()} batches generated: {num_batches}")
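A minimal sketch (not part of the diff) of how the fold indices yielded by get_cv_splits() might be consumed downstream; "pipeline" is the instance constructed in the example block above, and the model.fit / model.score calls are hypothetical placeholders, not part of this module:

    # Materialize per-fold train/validation arrays from the yielded indices
    X, y = pipeline.get_all_data('train')
    for train_idx, val_idx in pipeline.get_cv_splits():
        X_train, y_train = X[train_idx], y[train_idx]
        X_val, y_val = X[val_idx], y[val_idx]
        # model.fit(X_train, y_train)            # hypothetical estimator
        # fold_score = model.score(X_val, y_val)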