Switch to side-by-side view

--- a
+++ b/src/features/text_statistics.py
@@ -0,0 +1,199 @@
+import numpy as np
+import re
+from typing import List, Dict, Optional, Tuple
+from collections import defaultdict
+from src.preprocessing.preprocessing import create_ordered_medical_pipeline
+from src.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class TextStatisticsExtractor:
+    """Extract statistical features from medical texts based on EDA findings"""
+
+    def __init__(self, disease_category: Optional[str] = None):
+        self.disease_category = disease_category
+        self.logger = get_logger(self.__class__.__name__)
+
+        # Initialize preprocessing pipeline
+        self.preprocessor = create_ordered_medical_pipeline(
+            disease_category=disease_category
+        )
+
+        # Measurement patterns from EDA
+        self.measurement_patterns = {
+            'scores': r'\d+(?:\s*(?:points?|score))',
+            'percentages': r'\d+(?:\.\d+)?\s*%',
+            'ranges': r'\d+\s*(?:-|to)\s*\d+',
+            'units': r'\d+(?:\.\d+)?\s*(?:mg|kg|ml|cm|mm)',
+            'plus_minus': r'\d+\s*±\s*\d+'
+        }
+
+        # Disease-specific patterns from EDA
+        self.disease_patterns = {
+            'ALS': {
+                'scores': r'(?:ALSFRS-R|FVC)',
+                'measurements': r'\d+\s*(?:fvc|alsfrs)',
+                'time_patterns': r'(?:months?|years?)\s*(?:decline|progression)'
+            },
+            'OCD': {
+                'scores': r'(?:Y-BOCS|severity)',
+                'measurements': r'\d+\s*(?:ybocs|severity)',
+                'time_patterns': r'(?:frequency|duration)\s*of\s*(?:symptoms|behaviors)'
+            },
+            'Parkinson': {
+                'scores': r'(?:UPDRS|Hoehn)',
+                'measurements': r'\d+\s*(?:updrs|stage)',
+                'time_patterns': r'(?:onset|progression|duration)'
+            },
+            'Dementia': {
+                'scores': r'(?:MMSE|CDR)',
+                'measurements': r'\d+\s*(?:mmse|cdr)',
+                'time_patterns': r'(?:months?|years?)\s*(?:decline|progression)'
+            },
+            'Scoliosis': {
+                'scores': r'(?:Cobb|curve)',
+                'measurements': r'\d+\s*(?:degree|angle)',
+                'time_patterns': r'(?:growth|progression|correction)'
+            }
+        }
+
+    def extract_basic_statistics(self, text: str) -> Dict[str, float]:
+        """Extract basic text statistics"""
+        # Preprocess text
+        processed = self.preprocessor.process(text)
+        if isinstance(processed, tuple):
+            processed = processed[0]
+
+        words = processed.split()
+        sentences = [s.strip() for s in processed.split('.') if s.strip()]
+
+        return {
+            'word_count': len(words),
+            'sentence_count': len(sentences),
+            'avg_sentence_length': len(words) / len(sentences) if sentences else 0,
+            'avg_word_length': sum(len(w) for w in words) / len(words) if words else 0
+        }
+
+    def extract_measurement_statistics(self, text: str) -> Dict[str, float]:
+        """Extract measurement-related statistics"""
+        stats = {}
+
+        # Count measurements by type
+        for name, pattern in self.measurement_patterns.items():
+            matches = re.finditer(pattern, text, re.IGNORECASE)
+            stats[f'{name}_count'] = sum(1 for _ in matches)
+
+        # Calculate measurement density
+        total_measurements = sum(stats.values())
+        words = text.split()
+        stats['measurement_density'] = total_measurements / len(words) if words else 0
+
+        return stats
+
+    def extract_disease_specific_statistics(self, text: str) -> Dict[str, float]:
+        """Extract disease-specific statistics"""
+        stats = {}
+
+        if self.disease_category and self.disease_category in self.disease_patterns:
+            patterns = self.disease_patterns[self.disease_category]
+
+            for name, pattern in patterns.items():
+                matches = re.finditer(pattern, text, re.IGNORECASE)
+                stats[f'{self.disease_category.lower()}_{name}_count'] = sum(1 for _ in matches)
+
+        return stats
+
+    def extract_readability_statistics(self, text: str) -> Dict[str, float]:
+        """Extract readability statistics"""
+        words = text.split()
+        sentences = [s.strip() for s in text.split('.') if s.strip()]
+
+        # Count syllables (simple approximation)
+        def count_syllables(word):
+            return len(re.findall(r'[aeiou]+', word.lower())) + 1
+
+        syllable_counts = [count_syllables(w) for w in words]
+
+        stats = {
+            'avg_syllables_per_word': sum(syllable_counts) / len(words) if words else 0,
+            'complex_words_ratio': sum(1 for c in syllable_counts if c > 2) / len(words) if words else 0
+        }
+
+        # Approximate Flesch Reading Ease
+        if words and sentences:
+            stats['flesch_reading_ease'] = 206.835 - 1.015 * (len(words) / len(sentences)) - 84.6 * (
+                    sum(syllable_counts) / len(words))
+        else:
+            stats['flesch_reading_ease'] = 0
+
+        return stats
+
+    def extract_all_statistics(self, text: str) -> Dict[str, float]:
+        """Extract all statistical features"""
+        # Collect all statistics
+        stats = {}
+
+        # Basic statistics
+        stats.update(self.extract_basic_statistics(text))
+
+        # Measurement statistics
+        stats.update(self.extract_measurement_statistics(text))
+
+        # Disease-specific statistics
+        stats.update(self.extract_disease_specific_statistics(text))
+
+        # Readability statistics
+        stats.update(self.extract_readability_statistics(text))
+
+        return stats
+
+    def get_feature_vector(self, text: str) -> np.ndarray:
+        """Convert statistics to feature vector"""
+        stats = self.extract_all_statistics(text)
+        return np.array(list(stats.values()))
+
+    def get_feature_names(self) -> List[str]:
+        """Get names of statistical features"""
+        # Extract features from a sample text to get all feature names
+        stats = self.extract_all_statistics("Sample text")
+        return list(stats.keys())
+
+
+# Example usage and testing
+if __name__ == "__main__":
+    # Test texts
+    test_texts = [
+        """Patient with ALS showing respiratory decline. FVC = 65% ± 5%. 
+           ALSFRS-R score decreased from 42 to 38 over 3 months.""",
+        """Subject with severe OCD symptoms. Y-BOCS score: 28. 
+           Treatment includes cognitive behavioral therapy with daily monitoring.""",
+        """Parkinson's disease patient showing increased tremor. UPDRS score of 45. 
+           Started on levodopa 100mg/day with three-month follow-up."""
+    ]
+
+    # Test for different disease categories
+    for disease in ['ALS', 'OCD', 'Parkinson']:
+        logger.info(f"\nAnalyzing {disease} text:")
+
+        # Create statistics extractor
+        extractor = TextStatisticsExtractor(disease_category=disease)
+
+        # Get relevant test text
+        text = test_texts[['ALS', 'OCD', 'Parkinson'].index(disease)]
+
+        # Extract all statistics
+        stats = extractor.extract_all_statistics(text)
+
+        # Print results
+        logger.info("\nText Statistics:")
+        for feature, value in stats.items():
+            logger.info(f"{feature}: {value:.4f}")
+
+        # Get feature vector
+        feature_vector = extractor.get_feature_vector(text)
+        logger.info(f"\nFeature vector shape: {feature_vector.shape}")
+
+        # Get feature names
+        feature_names = extractor.get_feature_names()
+        logger.info(f"Number of features: {len(feature_names)}")
\ No newline at end of file