--- a +++ b/src/preprocessing/preprocessing.py @@ -0,0 +1,1047 @@ +from abc import ABC, abstractmethod +from typing import List, Set, Dict, Optional, Union, Tuple +import re +import string +import logging +from dataclasses import dataclass +from collections import defaultdict +from nltk.corpus import stopwords as nltk_stopwords +from src.utils.logger import get_logger + +logger = get_logger(__name__) + +# --------------------------- +# Utility function +# --------------------------- +def _get_plain_text(text: Union[str, "ProcessingResult"]) -> str: + """Helper to extract plain text from a ProcessingResult or any object with a .text attribute.""" + if hasattr(text, "text"): + return text.text + if isinstance(text, str): + return text + return str(text) + +# --------------------------- +# Data Classes +# --------------------------- +@dataclass +class ProcessingResult: + """Data class to hold processing results and metrics""" + text: str + original_length: int + processed_length: int + special_chars_removed: Dict[str, int] + numbers_detected: List[str] + medical_terms_found: List[str] + abbreviations_normalized: Dict[str, str] + + def split(self): + """Splits the text into tokens""" + return self.text.split() + + def __len__(self): + """Returns the length of the text""" + return len(self.text) + +@dataclass +class PreprocessingStats: + """Statistics about the preprocessing operation""" + original_length: int + processed_length: int + special_chars_removed: int + numbers_found: int + medical_terms_preserved: int + +class PreprocessingContext: + """Context for preprocessing operations""" + def __init__(self): + self.preserved_terms: Set[str] = set() + self.stats = {} + self.original_text: str = "" + self.metadata: Dict = {} + +# --------------------------- +# Abstract Base Class +# --------------------------- +class TextPreprocessor(ABC): + """Abstract base class for text preprocessors with enhanced medical text capabilities""" + def __init__(self): + self.logger = get_logger(self.__class__.__name__) + # Medical abbreviations based on EDA findings + self.medical_abbreviations = { + 'ALS': 'amyotrophic lateral sclerosis', + 'OCD': 'obsessive compulsive disorder', + 'PD': "parkinson's disease", + 'AD': "alzheimer's disease", + 'CBT': 'cognitive behavioral therapy' + } + # Disease-specific terms to preserve (from n-gram analysis) + self.preserve_terms = { + 'lateral sclerosis', 'amyotrophic', 'respiratory', 'muscle strength', + 'compulsive', 'obsessive', 'anxiety', 'behavioral therapy', + 'motor', 'levodopa', 'dopaminergic', 'movement disorders', + 'cognitive', 'alzheimer', 'caregivers', 'memory', + 'spine', 'curve', 'surgical', 'idiopathic' + } + # Medical measurements (from special character analysis) + self.measurement_patterns = [ + r'\d+\s*mg', r'\d+\s*kg', r'\d+\s*ml', + r'\d+\s*cm', r'\d+\s*mm', r'\d+\s*units' + ] + # Important special characters to handle (based on EDA) + self.special_chars_handling = { + '(': 'preserve', # Important for medical context + ')': 'preserve', + '.': 'preserve_with_space', + ',': 'preserve_with_space', + '-': 'preserve_in_numbers', + '/': 'preserve_in_measurements', + '%': 'preserve_with_number', + '±': 'preserve_in_measurements' + } + + def _preserve_medical_terms(self, text: str) -> Tuple[str, List[str]]: + """Preserve important medical terms and their context""" + text = _get_plain_text(text) + preserved_terms = [] + for term in self.preserve_terms: + if term.lower() in text.lower(): + preserved_terms.append(term) + pattern = re.compile(re.escape(term), re.IGNORECASE) + text = pattern.sub(lambda m: m.group().lower(), text) + return text, preserved_terms + + def _handle_abbreviations(self, text: str) -> Tuple[str, Dict[str, str]]: + """Handle medical abbreviations consistently""" + text = _get_plain_text(text) + normalized = {} + for abbr, full_form in self.medical_abbreviations.items(): + if abbr in text: + normalized[abbr] = full_form + text = text.replace(abbr.lower(), abbr) + text = text.replace(abbr.upper(), abbr) + return text, normalized + + def _process_measurements(self, text: str) -> Tuple[str, List[str]]: + """Handle medical measurements consistently""" + text = _get_plain_text(text) + measurements_found = [] + for pattern in self.measurement_patterns: + matches = re.finditer(pattern, text, re.IGNORECASE) + for match in matches: + measurement = match.group() + measurements_found.append(measurement) + text = text.replace(measurement, measurement.lower().replace(' ', '')) + return text, measurements_found + + def _handle_special_characters(self, text: str) -> Tuple[str, Dict[str, int]]: + """ + Process special characters based on their context. + This improved version uses a replacement function to check boundaries. + """ + text = _get_plain_text(text) + # Count the original occurrences of each special char. + char_counts = {char: text.count(char) for char in self.special_chars_handling.keys()} + + # Process each special character according to its handling rule. + for char, handling in self.special_chars_handling.items(): + # For 'preserve' rule, do nothing. + if handling == 'preserve': + continue + + # For the other rules, define a replacement function: + def replacer(match): + # Fetch the matched character. + current = match.group() + start, end = match.start(), match.end() + # For hyphen ('-'), check if it lies between two digits or letters. + if char == '-' and start > 0 and end < len(text): + prev_char, next_char = text[start - 1], text[end] + # If between digits or alphanumerics, do not replace. + if prev_char.isdigit() and next_char.isdigit(): + return current + if prev_char.isalnum() and next_char.isalnum(): + return current + # Otherwise, add a space before and after. + # (You might want to adjust this if you need to preserve some punctuation) + return f' {current} ' + + # Adjust regex patterns based on rules. + if handling == 'preserve_with_space': + # Replace every occurrence of the char with space around it, only if not already spaced. + pattern = re.compile(re.escape(char)) + text = pattern.sub(replacer, text) + + elif handling in ['preserve_in_numbers', 'preserve_with_number']: + # For these, only replace if the char is *not* between digits. + # (Note: if it's between digits, leave it unmodified.) + pattern = re.compile(re.escape(char)) + text = pattern.sub(lambda m: m.group() if ( + m.start() > 0 and m.string[m.start() - 1].isdigit() and + m.end() < len(m.string) and m.string[m.end()].isdigit()) + else replacer(m), text) + + elif handling == 'preserve_in_measurements': + # For measurements, try not to change the char when it is adjacent to digits or whitespace. + pattern = re.compile(re.escape(char)) + text = pattern.sub(lambda m: m.group() if ( + m.start() > 0 and (m.string[m.start() - 1].isdigit() or m.string[m.start() - 1].isspace()) + and m.end() < len(m.string) and (m.string[m.end()].isdigit() or m.string[m.end()].isspace())) + else replacer(m), text) + + return text, char_counts + + def _standardize_whitespace(self, text: str) -> str: + """Standardize whitespace while preserving important formatting""" + text = _get_plain_text(text) + text = re.sub(r'\s+', ' ', text) + text = re.sub(r'([.,!?])(?=[^\s.,!?])', r'\1 ', text) + return text.strip() + + def _extract_numbers(self, text: str) -> List[str]: + """Extract and standardize numerical values""" + text = _get_plain_text(text) + number_patterns = [r'\d+\.\d+', r'\d+', r'\d+\s*-\s*\d+', r'\d+\/\d+'] + numbers_found = [] + for pattern in number_patterns: + matches = re.finditer(pattern, text) + numbers_found.extend(match.group() for match in matches) + return numbers_found + + def validate_text(self, text: str) -> bool: + """Validate text for non-empty content""" + text = _get_plain_text(text) + if not text or len(text.strip()) == 0: + return False + return True + +# --------------------------- +# Pipeline Classes +# --------------------------- +class PreprocessingPipeline: + """Enhanced preprocessing pipeline for medical text""" + def __init__(self): + self.steps: List[TextPreprocessor] = [] + self.logger = get_logger(__name__) + self.context = PreprocessingContext() + self.max_length: Optional[int] = 9000 + + def add_step(self, processor: TextPreprocessor) -> None: + self.steps.append(processor) + self.logger.info(f"Added preprocessing step: {processor.get_name()}") + + def set_max_length(self, length: int) -> None: + self.max_length = length + self.logger.info(f"Set maximum length to {length}") + + def add_preserved_terms(self, terms: Set[str]) -> None: + self.context.preserved_terms.update(terms) + + def process(self, text: str, collect_stats: bool = True) -> Union[str, Tuple[str, PreprocessingStats]]: + text = _get_plain_text(text) + self.context.original_text = text + processed_text = text + stats = PreprocessingStats( + original_length=len(text), + processed_length=0, + special_chars_removed=0, + numbers_found=0, + medical_terms_preserved=0 + ) + try: + for step in self.steps: + self.logger.debug(f"Applying {step.get_name()}") + result = step.process(processed_text) + # Always update processed_text with the output from the step + processed_text = _get_plain_text(result) + if self.max_length and len(processed_text) > self.max_length: + processed_text = processed_text[:self.max_length] + self.logger.info(f"Truncated text to {self.max_length} characters") + if collect_stats: + stats.processed_length = len(processed_text) + return processed_text, stats + return processed_text + except Exception as e: + self.logger.error(f"Error in preprocessing pipeline: {str(e)}") + raise + + def process_batch(self, texts: List[str]) -> List[str]: + return [self.process(text, collect_stats=False) for text in texts] + + def get_pipeline_info(self) -> Dict: + return { + 'steps': [step.get_name() for step in self.steps], + 'max_length': self.max_length, + 'preserved_terms': len(self.context.preserved_terms) + } + + def reset_context(self) -> None: + self.context = PreprocessingContext() + +# --------------------------- +# Preprocessing Step Implementations +# --------------------------- +class LengthNormalizer(TextPreprocessor): + """Handles text length normalization based on EDA findings""" + def __init__(self, max_length: int = 9000): + super().__init__() + self.max_length = max_length + + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + original_length = len(text) + if original_length > self.max_length: + truncated = text[:self.max_length] + last_period = truncated.rfind('.') + if last_period > 0: + text = text[:last_period + 1] + else: + text = truncated + return ProcessingResult( + text=text, + original_length=original_length, + processed_length=len(text), + special_chars_removed={}, + numbers_detected=[], + medical_terms_found=[], + abbreviations_normalized={} + ) + + def get_name(self) -> str: + return "length_normalizer" + +class MedicalTermPreprocessor(TextPreprocessor): + """Handles medical terminology and abbreviations""" + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + text, abbreviations = self._handle_abbreviations(text) + text, preserved_terms = self._preserve_medical_terms(text) + text, measurements = self._process_measurements(text) + return ProcessingResult( + text=text, + original_length=len(text), + processed_length=len(text), + special_chars_removed={}, + numbers_detected=measurements, + medical_terms_found=preserved_terms, + abbreviations_normalized=abbreviations + ) + + def get_name(self) -> str: + return "medical_term_processor" + +class SpecialCharacterHandler(TextPreprocessor): + """Handles special characters based on EDA findings""" + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + text, char_counts = self._handle_special_characters(text) + numbers = self._extract_numbers(text) + return ProcessingResult( + text=text, + original_length=len(text), + processed_length=len(text), + special_chars_removed=char_counts, + numbers_detected=numbers, + medical_terms_found=[], + abbreviations_normalized={} + ) + + def get_name(self) -> str: + return "special_character_handler" + +class DiseaseCategoryPreprocessor(TextPreprocessor): + """Disease-specific preprocessing based on EDA findings""" + def __init__(self, disease_category: str): + super().__init__() + self.disease_category = disease_category + self.patterns = { + 'ALS': { + 'terms': ['amyotrophic lateral sclerosis', 'motor function', 'respiratory'], + 'measurements': [r'\d+\s*fvc', r'\d+\s*alsfrs'], + }, + 'OCD': { + 'terms': ['obsessive compulsive', 'anxiety', 'behavioral therapy'], + 'measurements': [r'\d+\s*ybocs', r'\d+\s*severity'], + }, + 'Parkinson': { + 'terms': ['motor symptoms', 'levodopa', 'dopaminergic'], + 'measurements': [r'\d+\s*updrs', r'\d+\s*hoehn'], + }, + 'Dementia': { + 'terms': ['cognitive function', 'memory', 'alzheimer'], + 'measurements': [r'\d+\s*mmse', r'\d+\s*cdr'], + }, + 'Scoliosis': { + 'terms': ['spinal curve', 'surgical correction', 'idiopathic'], + 'measurements': [r'\d+\s*degree', r'\d+\s*cobb'], + } + } + + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + disease_patterns = self.patterns.get(self.disease_category, {}) + preserved_terms = [] + for term in disease_patterns.get('terms', []): + if term.lower() in text.lower(): + preserved_terms.append(term) + pattern = re.compile(re.escape(term), re.IGNORECASE) + text = pattern.sub(lambda m: m.group().lower(), text) + measurements_found = [] + for pattern in disease_patterns.get('measurements', []): + matches = re.finditer(pattern, text, re.IGNORECASE) + measurements_found.extend(match.group() for match in matches) + return ProcessingResult( + text=text, + original_length=len(text), + processed_length=len(text), + special_chars_removed={}, + numbers_detected=measurements_found, + medical_terms_found=preserved_terms, + abbreviations_normalized={} + ) + + def get_name(self) -> str: + return f"{self.disease_category.lower()}_preprocessor" + +class MedicalScorePreprocessor(TextPreprocessor): + """Handles disease-specific medical scores based on EDA findings""" + def __init__(self): + super().__init__() + self.score_patterns = { + 'ALS': { + 'FVC': r'\d+\s*%?\s*FVC', + 'ALSFRS': r'ALSFRS-?R?\s*(?:score)?\s*:?\s*\d+', + }, + 'OCD': { + 'YBOCS': r'Y-?BOCS\s*(?:score)?\s*:?\s*\d+', + 'Severity': r'severity\s*(?:score)?\s*:?\s*\d+', + }, + 'Parkinson': { + 'UPDRS': r'UPDRS\s*(?:score)?\s*:?\s*\d+', + 'Hoehn': r'Hoehn\s*(?:and)?\s*Yahr\s*(?:stage)?\s*:?\s*\d+', + }, + 'Dementia': { + 'MMSE': r'MMSE\s*(?:score)?\s*:?\s*\d+', + 'CDR': r'CDR\s*(?:score)?\s*:?\s*[\d\.]+', + }, + 'Scoliosis': { + 'Cobb': r'Cobb\s*(?:angle)?\s*:?\s*\d+\s*°?', + 'Degree': r'\d+\s*(?:degree[s]?)?\s*(?:curve|angle)', + } + } + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + scores_found = [] + for disease, patterns in self.score_patterns.items(): + for score_type, pattern in patterns.items(): + matches = re.finditer(pattern, text, re.IGNORECASE) + for match in matches: + score = match.group() + scores_found.append(f"{disease}_{score_type}: {score}") + standardized = re.sub(r'\s+', ' ', score).strip() + text = text.replace(score, standardized) + return ProcessingResult( + text=text, + original_length=len(text), + processed_length=len(text), + special_chars_removed={}, + numbers_detected=scores_found, + medical_terms_found=[], + abbreviations_normalized={} + ) + def get_name(self) -> str: + return "medical_score_processor" + +class MedicalAbbreviationNormalizer(TextPreprocessor): + """Specialized processor for handling medical abbreviations based on EDA findings""" + def __init__(self): + super().__init__() + self.disease_abbreviations = { + 'ALS': 'amyotrophic lateral sclerosis', + 'OCD': 'obsessive compulsive disorder', + 'PD': "parkinson's disease", + 'AD': "alzheimer's disease", + } + self.measurement_abbreviations = { + 'mg': 'milligrams', + 'kg': 'kilograms', + 'ml': 'milliliters', + 'mm': 'millimeters', + 'cm': 'centimeters', + } + self.score_abbreviations = { + 'ALSFRS-R': 'ALS Functional Rating Scale Revised', + 'FVC': 'Forced Vital Capacity', + 'YBOCS': 'Yale-Brown Obsessive Compulsive Scale', + 'UPDRS': 'Unified Parkinson Disease Rating Scale', + 'MMSE': 'Mini-Mental State Examination', + 'CDR': 'Clinical Dementia Rating', + } + self.procedure_abbreviations = { + 'MRI': 'Magnetic Resonance Imaging', + 'CT': 'Computed Tomography', + 'EEG': 'Electroencephalogram', + 'EMG': 'Electromyography', + } + self.medical_term_abbreviations = { + 'tx': 'treatment', + 'dx': 'diagnosis', + 'hx': 'history', + 'pts': 'patients', + 'sig': 'significant', + } + def _create_abbreviation_pattern(self) -> str: + all_abbrevs = set() + all_abbrevs.update(self.disease_abbreviations.keys()) + all_abbrevs.update(self.score_abbreviations.keys()) + all_abbrevs.update(self.procedure_abbreviations.keys()) + all_abbrevs.update(self.medical_term_abbreviations.keys()) + sorted_abbrevs = sorted(all_abbrevs, key=len, reverse=True) + pattern = '|'.join(re.escape(abbr) for abbr in sorted_abbrevs) + return f'\\b({pattern})\\b' + def _detect_context(self, text: str, abbr: str, window: int = 50) -> bool: + text = _get_plain_text(text) + abbr_idx = text.find(abbr) + if abbr_idx == -1: + return True + start = max(0, abbr_idx - window) + end = min(len(text), abbr_idx + len(abbr) + window) + context = text[start:end].lower() + full_form = self._get_full_form(abbr).lower() + if full_form in context: + return False + return True + def _get_full_form(self, abbr: str) -> str: + for abbrev_dict in [self.disease_abbreviations, self.score_abbreviations, + self.procedure_abbreviations, self.medical_term_abbreviations]: + if abbr in abbrev_dict: + return abbrev_dict[abbr] + return abbr + def _normalize_measurements(self, text: str) -> Tuple[str, Dict[str, List[str]]]: + text = _get_plain_text(text) + normalized = text + found_measurements = defaultdict(list) + for abbr, full_form in self.measurement_abbreviations.items(): + pattern = f'(\\d+(?:\\.\\d+)?)\\s*{abbr}\\b' + matches = re.finditer(pattern, text, re.IGNORECASE) + for match in matches: + value = match.group(1) + found_measurements[abbr].append(value) + normalized = normalized.replace(match.group(), f"{value} {abbr}") + return normalized, dict(found_measurements) + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + original_text = text + abbreviations_found = {} + text, measurements = self._normalize_measurements(text) + pattern = self._create_abbreviation_pattern() + for match in re.finditer(pattern, text, re.IGNORECASE): + abbr = match.group() + if not self._detect_context(text, abbr): + continue + full_form = self._get_full_form(abbr) + abbreviations_found[abbr] = full_form + first_idx = text.find(abbr) + if first_idx >= 0: + replacement = f"{full_form} ({abbr})" + text = text[:first_idx] + replacement + text[first_idx + len(abbr):] + text = text.replace(abbr, abbr) + for abbr, values in measurements.items(): + abbreviations_found[abbr] = f"{self.measurement_abbreviations[abbr]} (found with values: {', '.join(values)})" + return ProcessingResult( + text=text, + original_length=len(original_text), + processed_length=len(text), + special_chars_removed={}, + numbers_detected=[], + medical_terms_found=[], + abbreviations_normalized=abbreviations_found + ) + def get_name(self) -> str: + return "medical_abbreviation_normalizer" + +class MedicalTermStandardizer(TextPreprocessor): + """Standardizes medical terminology based on EDA findings and disease categories""" + def __init__(self): + super().__init__() + self.disease_terms = { + 'ALS': { + 'motor neuron disease': 'amyotrophic lateral sclerosis', + 'muscle weakness': 'muscular weakness', + 'breathing difficulty': 'respiratory difficulty', + 'bulbar symptoms': 'bulbar dysfunction', + 'als disease': 'ALS', + 'motor function loss': 'motor deficit', + }, + 'OCD': { + 'obsessive thoughts': 'obsessions', + 'compulsive behaviors': 'compulsions', + 'intrusive thoughts': 'obsessive thoughts', + 'ritualistic behavior': 'compulsive behavior', + 'anxiety symptoms': 'anxiety manifestations', + 'ocd symptoms': 'OCD manifestations', + }, + 'Parkinson': { + 'tremors': 'tremor', + 'shaking': 'tremor', + 'movement problems': 'motor symptoms', + 'balance problems': 'postural instability', + 'stiffness': 'rigidity', + 'slow movement': 'bradykinesia', + }, + 'Dementia': { + 'memory loss': 'cognitive decline', + 'confusion': 'cognitive impairment', + 'behavioral changes': 'behavioral symptoms', + 'memory problems': 'memory impairment', + 'thinking problems': 'cognitive dysfunction', + 'mental decline': 'cognitive deterioration', + }, + 'Scoliosis': { + 'spine curvature': 'spinal curvature', + 'curved spine': 'spinal curvature', + 'back deformity': 'spinal deformity', + 'spine deviation': 'spinal deviation', + 'backbone curve': 'spinal curve', + 'spinal bend': 'spinal curvature', + } + } + self.common_terms = { + 'symptom': 'manifestation', + 'side effect': 'adverse effect', + 'adverse reaction': 'adverse effect', + 'negative effect': 'adverse effect', + 'medicine': 'medication', + 'drug': 'medication', + 'therapy': 'treatment', + 'therapeutic': 'treatment', + 'test': 'assessment', + 'examination': 'assessment', + 'evaluation': 'assessment', + 'checkup': 'examination', + 'condition': 'status', + 'state': 'status', + 'progression': 'disease progression', + 'improvement': 'clinical improvement', + 'deterioration': 'clinical deterioration' + } + self.measurement_terms = { + 'monthly': 'per month', + 'weekly': 'per week', + 'daily': 'per day', + 'yearly': 'per year', + 'twice daily': 'BID', + 'three times daily': 'TID', + 'four times daily': 'QID', + 'once daily': 'QD', + 'milliliters': 'ml', + 'milligrams': 'mg', + 'kilograms': 'kg', + 'centimeters': 'cm' + } + def _create_term_pattern(self, terms: Dict[str, str]) -> str: + sorted_terms = sorted(terms.keys(), key=len, reverse=True) + pattern = '|'.join(re.escape(term) for term in sorted_terms) + return f'\\b({pattern})\\b' + def _standardize_disease_terms(self, text: str, disease: str) -> Tuple[str, List[str]]: + text = _get_plain_text(text) + standardized_terms = [] + if disease in self.disease_terms: + pattern = self._create_term_pattern(self.disease_terms[disease]) + def replace_term(match): + term = match.group(0) + standardized = self.disease_terms[disease].get(term.lower(), term) + standardized_terms.append(f"{term} → {standardized}") + return standardized + text = re.sub(pattern, replace_term, text, flags=re.IGNORECASE) + return text, standardized_terms + def _standardize_common_terms(self, text: str) -> Tuple[str, List[str]]: + text = _get_plain_text(text) + standardized_terms = [] + pattern = self._create_term_pattern(self.common_terms) + def replace_term(match): + term = match.group(0) + standardized = self.common_terms.get(term.lower(), term) + standardized_terms.append(f"{term} → {standardized}") + return standardized + text = re.sub(pattern, replace_term, text, flags=re.IGNORECASE) + return text, standardized_terms + def _standardize_measurements(self, text: str) -> Tuple[str, List[str]]: + text = _get_plain_text(text) + standardized_terms = [] + pattern = self._create_term_pattern(self.measurement_terms) + def replace_term(match): + term = match.group(0) + standardized = self.measurement_terms.get(term.lower(), term) + standardized_terms.append(f"{term} → {standardized}") + return standardized + text = re.sub(pattern, replace_term, text, flags=re.IGNORECASE) + return text, standardized_terms + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + original_text = text + all_standardized_terms = [] + for disease in self.disease_terms.keys(): + text, disease_terms = self._standardize_disease_terms(text, disease) + all_standardized_terms.extend(disease_terms) + text, common_terms = self._standardize_common_terms(text) + all_standardized_terms.extend(common_terms) + text, measurement_terms = self._standardize_measurements(text) + all_standardized_terms.extend(measurement_terms) + return ProcessingResult( + text=text, + original_length=len(original_text), + processed_length=len(text), + special_chars_removed={}, + numbers_detected=[], + medical_terms_found=all_standardized_terms, + abbreviations_normalized={} + ) + def get_name(self) -> str: + return "medical_term_standardizer" + +class DomainSpecificStopwordHandler(TextPreprocessor): + """ + Handles stopwords specifically for medical trial descriptions, + preserving domain-critical terms based on EDA findings + """ + def __init__(self): + super().__init__() + self.standard_stopwords = set(nltk_stopwords.words('english')) + self.preserve_terms = { + 'disease', 'disorder', 'syndrome', 'condition', 'symptoms', + 'treatment', 'therapy', 'drug', 'dose', 'trial', + 'high', 'low', 'increase', 'decrease', 'level', + 'daily', 'weekly', 'monthly', 'duration', + 'positive', 'negative', 'normal', 'abnormal' + } + self.domain_stopwords = { + 'study', 'research', 'clinical', 'medical', 'patient', + 'subject', 'participant', 'investigator', 'physician', + 'protocol', 'procedure', 'visit', 'center', 'facility', + 'including', 'included', 'excluded', 'following', + 'based', 'related', 'associated', 'regarding' + } + self.disease_specific_terms = { + 'ALS': {'respiratory', 'bulbar', 'motor', 'muscle', 'weakness', + 'progression', 'function', 'strength', 'vital', 'capacity'}, + 'OCD': {'compulsive', 'obsessive', 'anxiety', 'behavior', + 'ritual', 'intrusive', 'thoughts', 'severity'}, + 'Parkinson': {'motor', 'tremor', 'rigidity', 'movement', 'balance', + 'gait', 'dopamine', 'dyskinesia'}, + 'Dementia': {'cognitive', 'memory', 'mental', 'behavioral', + 'function', 'decline', 'caregiver', 'activities'}, + 'Scoliosis': {'curve', 'spine', 'degree', 'correction', 'fusion', + 'thoracic', 'lumbar', 'surgical'} + } + def _is_measurement_or_value(self, token: str) -> bool: + measurement_patterns = [ + r'\d+', r'\d+\.\d+', r'\d+%', r'\d+mg', r'\d+kg', r'\d+ml', r'\d+cm' + ] + return any(re.match(pattern, token) for pattern in measurement_patterns) + def _is_medical_abbreviation(self, token: str) -> bool: + abbreviation_patterns = [ + r'^[A-Z]{2,}$', + r'^[A-Z]{2,}-[A-Z]$', + r'^[A-Z][a-z]+[A-Z]+$' + ] + return any(re.match(pattern, token) for pattern in abbreviation_patterns) + def _should_preserve(self, token: str, disease_category: Optional[str] = None) -> bool: + if self._is_measurement_or_value(token) or self._is_medical_abbreviation(token): + return True + token_lower = token.lower() + if token_lower in self.preserve_terms: + return True + if disease_category and token_lower in self.disease_specific_terms.get(disease_category, set()): + return True + return False + def process(self, text: str, disease_category: Optional[str] = None) -> ProcessingResult: + text = _get_plain_text(text) + original_text = text + preserved_terms = [] + removed_terms = [] + tokens = text.split() + processed_tokens = [] + for token in tokens: + if self._should_preserve(token, disease_category): + processed_tokens.append(token) + preserved_terms.append(token) + elif token.lower() in self.domain_stopwords or token.lower() in self.standard_stopwords: + removed_terms.append(token) + else: + processed_tokens.append(token) + processed_text = ' '.join(processed_tokens) + return ProcessingResult( + text=processed_text, + original_length=len(original_text), + processed_length=len(processed_text), + special_chars_removed={}, + numbers_detected=[], + medical_terms_found=preserved_terms, + abbreviations_normalized={'removed_stopwords': removed_terms} + ) + def get_name(self) -> str: + return "domain_specific_stopword_handler" + +class SpecialCharacterCleaner(TextPreprocessor): + """ + Specialized cleaner for handling special characters in medical texts + based on EDA findings + """ + def __init__(self): + super().__init__() + self.char_categories = { + 'preserve_always': {'.', ',', '(', ')', '%', '±', '-', '/'}, + 'preserve_with_numbers': {'.', '-', '/', '%', '±'}, + 'preserve_in_compounds': {'-', '/'}, + 'replace_with_space': {';', ':', '|', '\\', '[', ']', '{', '}', '_', '=', '+', '*'} + } + self.medical_patterns = { + 'measurements': [ + r'\d+(?:\.\d+)?\s*(?:mg|kg|ml|cm|mm)', + r'\d+(?:\.\d+)?\s*%', + r'\d+(?:\.\d+)?\s*±\s*\d+(?:\.\d+)?' + ], + 'ranges': [ + r'\d+\s*-\s*\d+', + r'\d+\s*/\s*\d+' + ], + 'compound_terms': [ + r'\w+(?:-\w+)+', + r'\w+(?:/\w+)+' + ], + 'statistical_values': [ + r'p\s*[<>]\s*0\.\d+', + r'[\+-]\s*\d+(?:\.\d+)?' + ] + } + + def _identify_protected_spans(self, text: str) -> List[Tuple[int, int, str]]: + text = _get_plain_text(text) + protected_spans = [] + for category, patterns in self.medical_patterns.items(): + for pattern in patterns: + for match in re.finditer(pattern, text): + protected_spans.append((match.start(), match.end(), category)) + return sorted(protected_spans, key=lambda x: x[0]) + + def _clean_special_chars(self, text: str, protected_spans: List[Tuple[int, int, str]]) -> Tuple[str, Dict[str, int]]: + text = _get_plain_text(text) + cleaned_text = "" + last_end = 0 + char_counts = defaultdict(int) + for start, end, span_type in protected_spans: + segment = text[last_end:start] + cleaned_segment = self._process_segment(segment, char_counts) + cleaned_text += cleaned_segment + cleaned_text += text[start:end] + last_end = end + if last_end < len(text): + final_segment = text[last_end:] + cleaned_text += self._process_segment(final_segment, char_counts) + return cleaned_text, dict(char_counts) + + def _process_segment(self, segment: str, char_counts: Dict[str, int]) -> str: + segment = _get_plain_text(segment) + # Instead of a fixed process, use our unified replacement function for "replace_with_space" characters. + for ch in self.char_categories['replace_with_space']: + # Count occurrences first. + char_counts[ch] += segment.count(ch) + # Add spaces around the character if not already spaced. + segment = re.sub(rf'(?<!\s){re.escape(ch)}(?!\s)', f' {ch} ', segment) + # Collapse multiple spaces. + segment = re.sub(r'\s+', ' ', segment) + return segment + + def _standardize_measurements(self, text: str) -> str: + text = _get_plain_text(text) + standardized = text + for pattern in self.medical_patterns['measurements']: + standardized = re.sub( + pattern, + lambda m: m.group().replace(' ', ''), + standardized + ) + for pattern in self.medical_patterns['ranges']: + standardized = re.sub( + pattern, + lambda m: m.group().replace(' ', ''), + standardized + ) + return standardized + + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + original_text = text + protected_spans = self._identify_protected_spans(text) + cleaned_text, char_counts = self._clean_special_chars(text, protected_spans) + final_text = self._standardize_measurements(cleaned_text) + preserved_patterns = [f"{span_type}: {text[start:end]}" for start, end, span_type in protected_spans] + return ProcessingResult( + text=final_text, + original_length=len(original_text), + processed_length=len(final_text), + special_chars_removed=char_counts, + numbers_detected=[], + medical_terms_found=preserved_patterns, + abbreviations_normalized={} + ) + + def get_name(self) -> str: + return "special_character_cleaner" + +class LowercasePreprocessor(TextPreprocessor): + """Convert text to lowercase""" + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + lowercased_text = text.lower() + return ProcessingResult( + text=lowercased_text, + original_length=len(text), + processed_length=len(lowercased_text), + special_chars_removed={}, + numbers_detected=[], + medical_terms_found=[], + abbreviations_normalized={} + ) + def get_name(self) -> str: + return "lowercase" + +class WhitespaceNormalizer(TextPreprocessor): + """Normalize whitespace in text""" + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + normalized_text = ' '.join(text.split()) + return ProcessingResult( + text=normalized_text, + original_length=len(text), + processed_length=len(normalized_text), + special_chars_removed={}, + numbers_detected=[], + medical_terms_found=[], + abbreviations_normalized={} + ) + def get_name(self) -> str: + return "whitespace_normalizer" + +class MedicalTokenizer(TextPreprocessor): + """Specialized tokenizer for medical text that preserves important terms""" + def __init__(self): + super().__init__() + self.term_boundaries = { + 'measurements': r'\d+\s*(?:mg|kg|ml|cm|mm)', + 'scores': r'\d+\s*(?:points?|score)', + 'percentages': r'\d+\s*%', + 'ranges': r'\d+\s*-\s*\d+', + } + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + preserved = {} + for term_type, pattern in self.term_boundaries.items(): + matches = re.finditer(pattern, text, re.IGNORECASE) + for i, match in enumerate(matches): + placeholder = f"__{term_type}_{i}__" + preserved[placeholder] = match.group() + text = text.replace(match.group(), placeholder) + tokens = text.split() + processed_text = ' '.join(tokens) + for placeholder, original in preserved.items(): + processed_text = processed_text.replace(placeholder, original) + return ProcessingResult( + text=processed_text, + original_length=len(text), + processed_length=len(processed_text), + special_chars_removed={}, + numbers_detected=list(preserved.values()), + medical_terms_found=[], + abbreviations_normalized={} + ) + def get_name(self) -> str: + return "medical_tokenizer" + +# --------------------------- +# Pipeline Builder +# --------------------------- +def create_ordered_medical_pipeline( + disease_category: Optional[str] = None, + config: Optional[Dict] = None +) -> PreprocessingPipeline: + """Create a properly ordered preprocessing pipeline with improved text handling.""" + pipeline = PreprocessingPipeline() + default_config = { + 'max_length': 9000, + 'preserve_case': True, + 'include_scores': True, + 'standardize_terms': True, + 'handle_stopwords': False, + 'preserve_measurements': True + } + config = {**default_config, **(config or {})} + + class ImprovedSpecialCharacterCleaner(SpecialCharacterCleaner): + def _standardize_measurements(self, text: str) -> str: + standardized = text + for pattern in self.medical_patterns['measurements']: + standardized = re.sub( + pattern, + lambda m: re.sub(r'(\d+)\s*([a-zA-Z%]+)', r'\1 \2', m.group()), + standardized + ) + standardized = re.sub( + r'(\d+)\s*-\s*(\d+)', + r'\1 - \2', + standardized + ) + standardized = re.sub(r'\s*([±])\s*', r' \1 ', standardized) + return standardized + + class ImprovedMedicalTokenizer(MedicalTokenizer): + def process(self, text: str) -> ProcessingResult: + text = _get_plain_text(text) + preserved = {} + for term_type, pattern in self.term_boundaries.items(): + matches = re.finditer(pattern, text, re.IGNORECASE) + for i, match in enumerate(matches): + placeholder = f"__{term_type}_{i}__" + preserved[placeholder] = match.group() + text = text.replace(match.group(), placeholder) + text = re.sub(r'([A-Z]-[A-Z])', r' \1 ', text) + text = re.sub(r'([A-Z]+)([a-z])', r'\1 \2', text) + text = ' '.join(text.split()) + for placeholder, original in preserved.items(): + text = text.replace(placeholder, original) + return ProcessingResult( + text=text, + original_length=len(text), + processed_length=len(text), + special_chars_removed={}, + numbers_detected=list(preserved.values()), + medical_terms_found=[], + abbreviations_normalized={} + ) + + # 1. Initial Text Cleanup + pipeline.add_step(WhitespaceNormalizer()) + # 2. Length Normalization + pipeline.add_step(LengthNormalizer(max_length=config['max_length'])) + # 3. Improved Special Character Processing + pipeline.add_step(ImprovedSpecialCharacterCleaner()) + # 4. Medical Content Processing + if config['include_scores']: + pipeline.add_step(MedicalScorePreprocessor()) + pipeline.add_step(MedicalAbbreviationNormalizer()) + # 5. Term Standardization + if config['standardize_terms']: + pipeline.add_step(MedicalTermStandardizer()) + # 6. Disease-Specific Processing + if disease_category: + pipeline.add_step(DiseaseCategoryPreprocessor(disease_category)) + # 7. Stop Words + if config['handle_stopwords']: + pipeline.add_step(DomainSpecificStopwordHandler()) + # 8. Lowercase (if not preserving case) + if not config.get('preserve_case', True): + pipeline.add_step(LowercasePreprocessor()) + # 9. Final Tokenization with Improved Handling + pipeline.add_step(ImprovedMedicalTokenizer()) + return pipeline \ No newline at end of file