--- /dev/null
+++ b/src/features/text_statistics.py
@@ -0,0 +1,199 @@
+import numpy as np
+import re
+from typing import List, Dict, Optional, Tuple
+from collections import defaultdict
+from src.preprocessing.preprocessing import create_ordered_medical_pipeline
+from src.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class TextStatisticsExtractor:
+    """Extract statistical features from medical texts based on EDA findings"""
+
+    def __init__(self, disease_category: Optional[str] = None):
+        self.disease_category = disease_category
+        self.logger = get_logger(self.__class__.__name__)
+
+        # Initialize preprocessing pipeline
+        self.preprocessor = create_ordered_medical_pipeline(
+            disease_category=disease_category
+        )
+
+        # Measurement patterns from EDA
+        self.measurement_patterns = {
+            'scores': r'\d+(?:\s*(?:points?|score))',
+            'percentages': r'\d+(?:\.\d+)?\s*%',
+            'ranges': r'\d+\s*(?:-|to)\s*\d+',
+            'units': r'\d+(?:\.\d+)?\s*(?:mg|kg|ml|cm|mm)',
+            'plus_minus': r'\d+\s*±\s*\d+'
+        }
+
+        # Disease-specific patterns from EDA
+        self.disease_patterns = {
+            'ALS': {
+                'scores': r'(?:ALSFRS-R|FVC)',
+                'measurements': r'\d+\s*(?:fvc|alsfrs)',
+                'time_patterns': r'(?:months?|years?)\s*(?:decline|progression)'
+            },
+            'OCD': {
+                'scores': r'(?:Y-BOCS|severity)',
+                'measurements': r'\d+\s*(?:ybocs|severity)',
+                'time_patterns': r'(?:frequency|duration)\s*of\s*(?:symptoms|behaviors)'
+            },
+            'Parkinson': {
+                'scores': r'(?:UPDRS|Hoehn)',
+                'measurements': r'\d+\s*(?:updrs|stage)',
+                'time_patterns': r'(?:onset|progression|duration)'
+            },
+            'Dementia': {
+                'scores': r'(?:MMSE|CDR)',
+                'measurements': r'\d+\s*(?:mmse|cdr)',
+                'time_patterns': r'(?:months?|years?)\s*(?:decline|progression)'
+            },
+            'Scoliosis': {
+                'scores': r'(?:Cobb|curve)',
+                'measurements': r'\d+\s*(?:degree|angle)',
+                'time_patterns': r'(?:growth|progression|correction)'
+            }
+        }
+
+    def extract_basic_statistics(self, text: str) -> Dict[str, float]:
+        """Extract basic text statistics"""
+        # Preprocess text
+        processed = self.preprocessor.process(text)
+        if isinstance(processed, tuple):
+            processed = processed[0]
+
+        words = processed.split()
+        sentences = [s.strip() for s in processed.split('.') if s.strip()]
+
+        return {
+            'word_count': len(words),
+            'sentence_count': len(sentences),
+            'avg_sentence_length': len(words) / len(sentences) if sentences else 0,
+            'avg_word_length': sum(len(w) for w in words) / len(words) if words else 0
+        }
+
+    def extract_measurement_statistics(self, text: str) -> Dict[str, float]:
+        """Extract measurement-related statistics"""
+        stats = {}
+
+        # Count measurements by type
+        for name, pattern in self.measurement_patterns.items():
+            matches = re.finditer(pattern, text, re.IGNORECASE)
+            stats[f'{name}_count'] = sum(1 for _ in matches)
+
+        # Calculate measurement density
+        total_measurements = sum(stats.values())
+        words = text.split()
+        stats['measurement_density'] = total_measurements / len(words) if words else 0
+
+        return stats
+
+    def extract_disease_specific_statistics(self, text: str) -> Dict[str, float]:
+        """Extract disease-specific statistics"""
+        stats = {}
+
+        if self.disease_category and self.disease_category in self.disease_patterns:
+            patterns = self.disease_patterns[self.disease_category]
+
+            for name, pattern in patterns.items():
+                matches = re.finditer(pattern, text, re.IGNORECASE)
+                stats[f'{self.disease_category.lower()}_{name}_count'] = sum(1 for _ in matches)
+
+        return stats
+
+    def extract_readability_statistics(self, text: str) -> Dict[str, float]:
+        """Extract readability statistics"""
readability statistics""" + words = text.split() + sentences = [s.strip() for s in text.split('.') if s.strip()] + + # Count syllables (simple approximation) + def count_syllables(word): + return len(re.findall(r'[aeiou]+', word.lower())) + 1 + + syllable_counts = [count_syllables(w) for w in words] + + stats = { + 'avg_syllables_per_word': sum(syllable_counts) / len(words) if words else 0, + 'complex_words_ratio': sum(1 for c in syllable_counts if c > 2) / len(words) if words else 0 + } + + # Approximate Flesch Reading Ease + if words and sentences: + stats['flesch_reading_ease'] = 206.835 - 1.015 * (len(words) / len(sentences)) - 84.6 * ( + sum(syllable_counts) / len(words)) + else: + stats['flesch_reading_ease'] = 0 + + return stats + + def extract_all_statistics(self, text: str) -> Dict[str, float]: + """Extract all statistical features""" + # Collect all statistics + stats = {} + + # Basic statistics + stats.update(self.extract_basic_statistics(text)) + + # Measurement statistics + stats.update(self.extract_measurement_statistics(text)) + + # Disease-specific statistics + stats.update(self.extract_disease_specific_statistics(text)) + + # Readability statistics + stats.update(self.extract_readability_statistics(text)) + + return stats + + def get_feature_vector(self, text: str) -> np.ndarray: + """Convert statistics to feature vector""" + stats = self.extract_all_statistics(text) + return np.array(list(stats.values())) + + def get_feature_names(self) -> List[str]: + """Get names of statistical features""" + # Extract features from a sample text to get all feature names + stats = self.extract_all_statistics("Sample text") + return list(stats.keys()) + + +# Example usage and testing +if __name__ == "__main__": + # Test texts + test_texts = [ + """Patient with ALS showing respiratory decline. FVC = 65% ± 5%. + ALSFRS-R score decreased from 42 to 38 over 3 months.""", + """Subject with severe OCD symptoms. Y-BOCS score: 28. + Treatment includes cognitive behavioral therapy with daily monitoring.""", + """Parkinson's disease patient showing increased tremor. UPDRS score of 45. + Started on levodopa 100mg/day with three-month follow-up.""" + ] + + # Test for different disease categories + for disease in ['ALS', 'OCD', 'Parkinson']: + logger.info(f"\nAnalyzing {disease} text:") + + # Create statistics extractor + extractor = TextStatisticsExtractor(disease_category=disease) + + # Get relevant test text + text = test_texts[['ALS', 'OCD', 'Parkinson'].index(disease)] + + # Extract all statistics + stats = extractor.extract_all_statistics(text) + + # Print results + logger.info("\nText Statistics:") + for feature, value in stats.items(): + logger.info(f"{feature}: {value:.4f}") + + # Get feature vector + feature_vector = extractor.get_feature_vector(text) + logger.info(f"\nFeature vector shape: {feature_vector.shape}") + + # Get feature names + feature_names = extractor.get_feature_names() + logger.info(f"Number of features: {len(feature_names)}") \ No newline at end of file