--- a
+++ b/src/data/data_pipeline.py
@@ -0,0 +1,169 @@
+import numpy as np
+import pandas as pd
+from typing import Dict, List, Tuple, Optional, Generator
+from sklearn.model_selection import StratifiedKFold
+from sklearn.utils import shuffle
+import joblib
+from pathlib import Path
+import os
+from src.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class DataPipeline:
+    """
+    Data pipeline serving prepared train/val/test arrays for model training
+
+    Features:
+    - Stratified sampling
+    - Cross-validation splits
+    - Data caching
+    - Batch generation
+    """
+
+    def __init__(self,
+                 data_dir: str,
+                 batch_size: int = 32,
+                 n_splits: int = 5,
+                 cache_dir: Optional[str] = None,
+                 random_state: int = 42):
+        """Load the prepared splits from data_dir and set up the CV splitter.
+
+        batch_size is the default for get_batch_generator; cache_dir, when
+        given, enables on-disk caching; random_state seeds shuffling and CV.
+        """
+        self.data_dir = Path(data_dir)
+        self.batch_size = batch_size
+        self.n_splits = n_splits
+        self.cache_dir = Path(cache_dir) if cache_dir else None
+        self.random_state = random_state
+        self.logger = get_logger(self.__class__.__name__)
+        self.load_data()  # reads every split eagerly; re-raises any loading error
+        self.cv = StratifiedKFold(
+            n_splits=self.n_splits,
+            shuffle=True,
+            random_state=self.random_state
+        )
+
+    def load_data(self):
+        """Load features, labels and metadata from ``self.data_dir``.
+
+        Populates ``self.data`` (split name -> (features, labels)) and
+        ``self.metadata``; any failure is logged and re-raised unchanged.
+        """
+        try:
+            self.data = {}
+            for split in ['train', 'val', 'test']:
+                features = np.load(self.data_dir / f'{split}_features.npy')
+                labels = np.load(self.data_dir / f'{split}_labels.npy')
+                self.data[split] = (features, labels)
+
+            # NOTE(review): joblib.load unpickles its input; only point this at trusted files
+            metadata_path = self.data_dir / 'metadata.joblib'
+            self.metadata = joblib.load(metadata_path)
+
+            self.logger.info("Data loaded successfully")
+            self._log_data_info()
+        except Exception as e:
+            self.logger.error(f"Error loading data: {str(e)}")
+            raise
+
+    def _log_data_info(self):
+        """Log the feature/label shapes and the distinct label values of every split"""
+        for split, (features, labels) in self.data.items():
+            self.logger.info(f"\n{split.capitalize()} set:")
+            self.logger.info(f"Features shape: {features.shape}")
+            self.logger.info(f"Labels shape: {labels.shape}")
+            self.logger.info(f"Classes: {np.unique(labels)}")
+
+    def get_cv_splits(self) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
+        """Yield (train_idx, val_idx) index arrays for each stratified fold of 'train'"""
+        features, labels = self.data['train']
+        for fold, (train_idx, val_idx) in enumerate(self.cv.split(features, labels)):
+            self.logger.info(f"Generating split for fold {fold + 1}/{self.n_splits}")
+            yield train_idx, val_idx
+
+    def get_batch_generator(self,
+                            split: str,
+                            batch_size: Optional[int] = None,
+                            shuffle_data: bool = True) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
+        """Yield (features, labels) batches for a split; the last one may be short.
+
+        batch_size defaults to self.batch_size. Shuffling is seeded with
+        self.random_state, so every call yields the same order.
+        Raises ValueError (on first iteration) if batch_size is below 1.
+        """
+        features, labels = self.data[split]
+        batch_size = batch_size if batch_size is not None else self.batch_size
+        if batch_size < 1:
+            # 0 used to fail with ZeroDivisionError; negative sizes produced no usable batches
+            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
+        if self.cache_dir:
+            # NOTE(review): a cache hit is used as-is, even if data_dir changed since it was written
+            cache_key = f"{split}_batch_{batch_size}"
+            if not self._check_cache(cache_key):
+                # Save the original, unshuffled data to cache.
+                self._save_to_cache(cache_key, features, labels)
+            else:
+                features, labels = self._load_from_cache(cache_key)
+        if shuffle_data:
+            features, labels = shuffle(features, labels, random_state=self.random_state)
+        num_samples = len(features)
+        for start_idx in range(0, num_samples, batch_size):
+            end_idx = min(start_idx + batch_size, num_samples)
+            yield features[start_idx:end_idx], labels[start_idx:end_idx]
+
+    def get_all_data(self, split: str) -> Tuple[np.ndarray, np.ndarray]:
+        """Return the full (features, labels) tuple loaded for a split"""
+        return self.data[split]
+
+    def _check_cache(self, key: str) -> bool:
+        """Return True only if both cached arrays for key exist (False when caching is off)"""
+        if not self.cache_dir:
+            return False
+
+        features_path = self.cache_dir / f"{key}_features.npy"
+        labels_path = self.cache_dir / f"{key}_labels.npy"
+        return features_path.exists() and labels_path.exists()
+
+    def _load_from_cache(self, key: str) -> Tuple[np.ndarray, np.ndarray]:
+        """Load the cached (features, labels) arrays for key; requires cache_dir to be set"""
+        features = np.load(self.cache_dir / f"{key}_features.npy")
+        labels = np.load(self.cache_dir / f"{key}_labels.npy")
+        self.logger.info(f"Loaded cache with key: {key}")
+        return features, labels
+
+    def _save_to_cache(self, key: str, features: np.ndarray, labels: np.ndarray):
+        """Write features and labels under key, creating cache_dir if needed; no-op without one"""
+        if self.cache_dir:
+            self.cache_dir.mkdir(parents=True, exist_ok=True)
+            np.save(self.cache_dir / f"{key}_features.npy", features)
+            np.save(self.cache_dir / f"{key}_labels.npy", labels)
+            self.logger.info(f"Saved cache with key: {key}")
+
+
+# Example usage and testing
+if __name__ == "__main__":
+    # Smoke test: build a pipeline - adjust paths as needed
+    pipeline = DataPipeline(
+        data_dir='../../data/prepared_data',
+        batch_size=32,
+        n_splits=5,
+        cache_dir='../../data/cache',
+    )
+
+    # Walk every cross-validation fold and report the size of each side
+    logger.info("\nTesting cross-validation splits:")
+    cv_folds = pipeline.get_cv_splits()
+    for fold_idx, (fit_idx, holdout_idx) in enumerate(cv_folds):
+        logger.info(f"Fold {fold_idx + 1}:")
+        logger.info(f"Training samples: {len(fit_idx)}")
+        logger.info(f"Validation samples: {len(holdout_idx)}")
+
+    # Drain the batch generator of each split and report how many batches came out
+    logger.info("\nTesting batch generator:")
+    for split_name in ('train', 'val', 'test'):
+        batches = pipeline.get_batch_generator(split_name)
+        # Only the number of batches matters here, not their contents
+        batch_count = sum(1 for _ in batches)
+        logger.info(f"{split_name.capitalize()} batches generated: {batch_count}")