|
a |
|
b/main.py |
|
|
1 |
from flask import Flask, jsonify, request |
|
|
2 |
import numpy as np |
|
|
3 |
import joblib |
|
|
4 |
import traceback |
|
|
5 |
from pathlib import Path |
|
|
6 |
from src.preprocessing.preprocessing import create_ordered_medical_pipeline |
|
|
7 |
from src.utils.logger import get_logger |
|
|
8 |
from typing import Dict, Any |
|
|
9 |
|
|
|
10 |
# Initialize Flask app at module level so the route decorators below can
# register against it; also serves as the WSGI entry point for servers
# (e.g. gunicorn main:app).
app = Flask(__name__)
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class PredictionService:
    """Service for handling model predictions.

    Loads a joblib-serialized pipeline dict (expected keys: 'vectorizer',
    'scaler', 'model', 'feature_dim', optional 'metadata' with a
    'label_encoder') and exposes ``predict`` for scoring free-text input.
    """

    def __init__(self, model_path: str = 'src/models/model.joblib'):
        """Create the service and eagerly load the model pipeline.

        Args:
            model_path: Filesystem path to the joblib-serialized pipeline.

        Raises:
            Exception: Re-raised from ``initialize`` when loading fails.
        """
        self.model_path = Path(model_path)
        self.pipeline = None
        self.preprocessor = None
        self.logger = get_logger(__name__)
        try:
            self.initialize()
        except Exception as e:
            # Log full context before propagating so startup failures are diagnosable.
            error_traceback = traceback.format_exc()
            self.logger.error(f"Initialization failed with error: {str(e)}")
            self.logger.error(f"Traceback: {error_traceback}")
            raise

    def initialize(self):
        """Load the pipeline from disk and build the text preprocessor.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        try:
            self.logger.info(f"Looking for model at: {self.model_path.absolute()}")

            # Fail fast with an absolute path in the message for easier debugging.
            if not self.model_path.exists():
                raise FileNotFoundError(f"Model file not found at {self.model_path.absolute()}")

            self.pipeline = joblib.load(self.model_path)
            self.logger.info("Model pipeline loaded successfully")

            # Log pipeline contents so a mismatched artifact is visible at startup.
            self.logger.info("Pipeline contents:")
            for key in self.pipeline.keys():
                self.logger.info(f"- Found component: {key}")

            self.logger.info("Initializing preprocessor...")
            self.preprocessor = create_ordered_medical_pipeline()
            self.logger.info("Preprocessor initialized successfully")

            # Log feature dimensions to catch vectorizer/model drift early.
            self.logger.info(f"Expected feature dimension: {self.pipeline['feature_dim']}")
            self.logger.info(f"Vectorizer vocabulary size: {len(self.pipeline['vectorizer'].vocabulary_)}")

        except Exception as e:
            error_traceback = traceback.format_exc()
            self.logger.error(f"Initialization failed with error: {str(e)}")
            self.logger.error(f"Traceback: {error_traceback}")
            raise

    def predict(self, text: str) -> Dict[str, Any]:
        """Make prediction on input text.

        Args:
            text: Raw free-text description to classify.

        Returns:
            Dict with 'status' ('success'), 'prediction' (a native,
            JSON-serializable value), and 'confidence' (float or None
            when the model has no ``predict_proba``).

        Raises:
            ValueError: On non-string/empty input or a feature-dimension
                mismatch between vectorizer output and the stored pipeline.
        """
        try:
            # Validate input
            if not isinstance(text, str):
                raise ValueError("Input must be a string")
            if not text.strip():
                raise ValueError("Input text cannot be empty")

            # Preprocess text; some preprocessors return (text, metadata) —
            # keep only the text portion.
            processed_text = self.preprocessor.process(text)
            if isinstance(processed_text, tuple):
                processed_text = processed_text[0]

            # Extract features using the fitted vectorizer (dense array for scaler).
            features = self.pipeline['vectorizer'].transform([processed_text])
            features = features.toarray()

            # Verify feature dimension against the artifact's recorded size.
            if features.shape[1] != self.pipeline['feature_dim']:
                raise ValueError(
                    f"Feature dimension mismatch: got {features.shape[1]}, "
                    f"expected {self.pipeline['feature_dim']}"
                )

            # Scale features with the fitted scaler.
            features_scaled = self.pipeline['scaler'].transform(features)

            # The stored model (e.g. a VotingClassifier) does the prediction.
            model = self.pipeline['model']
            prediction = model.predict(features_scaled)[0]

            # Get prediction probability if the model supports it.
            confidence = None
            if hasattr(model, 'predict_proba'):
                probabilities = model.predict_proba(features_scaled)[0]
                confidence = float(np.max(probabilities))

            # Map back to the original label if an encoder was stored.
            if 'metadata' in self.pipeline and 'label_encoder' in self.pipeline['metadata']:
                prediction = self.pipeline['metadata']['label_encoder'].inverse_transform([prediction])[0]

            # BUG FIX: model.predict returns numpy scalars (e.g. np.int64),
            # which Flask's jsonify cannot serialize; convert to the
            # equivalent native Python type before returning.
            if isinstance(prediction, np.generic):
                prediction = prediction.item()

            return {
                'status': 'success',
                'prediction': prediction,
                'confidence': confidence
            }

        except Exception as e:
            error_traceback = traceback.format_exc()
            self.logger.error(f"Prediction failed: {str(e)}")
            self.logger.error(f"Traceback: {error_traceback}")
            raise
|
|
118 |
|
|
|
119 |
|
|
|
120 |
# Build the module-level prediction service. Routes degrade gracefully
# (503 / unhealthy responses) when construction fails, so a bad model
# artifact does not prevent the app from starting.
prediction_service = None
try:
    prediction_service = PredictionService()
except Exception as e:
    app.logger.error(f"Failed to initialize PredictionService: {str(e)}")
    traceback.print_exc()
else:
    app.logger.info("PredictionService initialized successfully")
|
|
128 |
|
|
|
129 |
|
|
|
130 |
@app.route("/health")
def health_check():
    """Report liveness and the availability of each pipeline component."""
    # Guard clause: the service never came up, so report unhealthy outright.
    if prediction_service is None:
        payload = {
            'status': 'unhealthy',
            'error': 'Prediction service failed to initialize'
        }
        return jsonify(payload), 500

    pipeline = prediction_service.pipeline
    components = {
        name: pipeline is not None and name in pipeline
        for name in ('model', 'vectorizer', 'scaler')
    }
    components['preprocessor'] = prediction_service.preprocessor is not None

    return jsonify({'status': 'healthy', 'components': components})
|
|
149 |
|
|
|
150 |
|
|
|
151 |
@app.route("/predict", methods=["POST"])
def predict():
    """Prediction endpoint.

    Expects a JSON body of the form ``{"description": "<text>"}`` and
    returns the service's result dict on success. Error responses carry
    ``{'status': 'error', 'message': ...}`` with 400/500/503 as appropriate.
    """
    # Guard clause: service failed to initialize at startup.
    if prediction_service is None:
        return jsonify({
            'status': 'error',
            'message': 'Prediction service is not available'
        }), 503

    try:
        # ROBUSTNESS FIX: silent=True makes malformed JSON (or a wrong
        # Content-Type) yield None here — handled by our own 400 below —
        # instead of an unhandled werkzeug 400/415 error page.
        data = request.get_json(silent=True)

        # Validate request data
        if not data:
            return jsonify({
                'status': 'error',
                'message': 'No data provided'
            }), 400

        if 'description' not in data:
            return jsonify({
                'status': 'error',
                'message': 'No description provided'
            }), 400

        # Delegate to the service; it validates the text itself.
        result = prediction_service.predict(data['description'])
        return jsonify(result)

    except ValueError as e:
        # Input-validation failures raised by the service map to 400.
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 400
    except Exception as e:
        # SECURITY FIX: log the traceback server-side but do not echo it
        # to the client — internal stack traces in responses leak
        # implementation details to untrusted callers.
        error_traceback = traceback.format_exc()
        app.logger.error(f"Prediction failed: {str(e)}")
        app.logger.error(f"Traceback: {error_traceback}")
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500
|
|
196 |
|
|
|
197 |
|
|
|
198 |
if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug interactive debugger and,
    # combined with host='0.0.0.0', makes it reachable from the network.
    # Acceptable for local development only — confirm this is never used in
    # production (use a WSGI server such as gunicorn there instead).
    app.run(debug=True, host='0.0.0.0', port=5000)