--- a
+++ b/app.py
@@ -0,0 +1,78 @@
+from flask import Flask, request, jsonify
+import torch
+from torchvision import transforms
+from PIL import Image
+import io
+import torch.nn as nn
+import timm
+
+# Define the model class
+class LeukemiaModel(nn.Module):
+    def __init__(self, num_classes=4):
+        super(LeukemiaModel, self).__init__()
+        
+        # Pretrained EfficientNetB7
+        self.base_model = timm.create_model('efficientnet_b7', pretrained=False, num_classes=0)
+        
+        # Custom layers
+        self.flatten = nn.Flatten()
+        self.dense1 = nn.Linear(2560 * 7 * 7, 256)  # Adjust input size if needed
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(p=0.3)
+        self.dense2 = nn.Linear(256, num_classes)  # Final output layer for classification
+        self.softmax = nn.Softmax(dim=1)  # Softmax for multi-class classification
+        
+    def forward(self, x):
+        # Forward pass through the base model (EfficientNetB7)
+        x = self.base_model(x)
+        
+        # Flatten and pass through custom layers
+        x = self.flatten(x)
+        x = self.dense1(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.dense2(x)
+        x = self.softmax(x)
+        
+        return x
+
+
+# Initialize Flask app
+app = Flask(__name__)
+
+# Load the pre-trained model
+model = LeukemiaModel(num_classes=4)
+model.load_state_dict(torch.load('leukemia_model.pth', weights_only=True))
+model.eval()  # Set the model to evaluation mode
+
+# Define the image transformation pipeline
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),  # Resize image to match EfficientNet input size
+    transforms.ToTensor(),  # Convert image to tensor
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
+])
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    if 'file' not in request.files:
+        return jsonify({'error': 'No file part'}), 400
+    
+    file = request.files['file']
+    
+    # Read and process the image
+    try:
+        img_bytes = file.read()
+        img = Image.open(io.BytesIO(img_bytes))
+        img = transform(img).unsqueeze(0)  # Add batch dimension
+    except Exception as e:
+        return jsonify({'error': str(e)}), 400
+    
+    # Inference
+    with torch.no_grad():
+        outputs = model(img)
+        _, predicted = torch.max(outputs, 1)  # Get the class with the highest probability
+    
+    return jsonify({'prediction': int(predicted.item())})
+
+if __name__ == '__main__':
+    app.run(debug=True)