Diff of /app.py [000000] .. [d366d1]

Switch to unified view

a b/app.py
1
from flask import Flask, request, jsonify
2
import torch
3
from torchvision import transforms
4
from PIL import Image
5
import io
6
import torch.nn as nn
7
import timm
8
9
# Define the model class
10
class LeukemiaModel(nn.Module):
11
    def __init__(self, num_classes=4):
12
        super(LeukemiaModel, self).__init__()
13
        
14
        # Pretrained EfficientNetB7
15
        self.base_model = timm.create_model('efficientnet_b7', pretrained=False, num_classes=0)
16
        
17
        # Custom layers
18
        self.flatten = nn.Flatten()
19
        self.dense1 = nn.Linear(2560 * 7 * 7, 256)  # Adjust input size if needed
20
        self.relu = nn.ReLU()
21
        self.dropout = nn.Dropout(p=0.3)
22
        self.dense2 = nn.Linear(256, num_classes)  # Final output layer for classification
23
        self.softmax = nn.Softmax(dim=1)  # Softmax for multi-class classification
24
        
25
    def forward(self, x):
26
        # Forward pass through the base model (EfficientNetB7)
27
        x = self.base_model(x)
28
        
29
        # Flatten and pass through custom layers
30
        x = self.flatten(x)
31
        x = self.dense1(x)
32
        x = self.relu(x)
33
        x = self.dropout(x)
34
        x = self.dense2(x)
35
        x = self.softmax(x)
36
        
37
        return x
38
39
40
# Initialize Flask app
41
app = Flask(__name__)
42
43
# Load the pre-trained model
44
model = LeukemiaModel(num_classes=4)
45
model.load_state_dict(torch.load('leukemia_model.pth', weights_only=True))
46
model.eval()  # Set the model to evaluation mode
47
48
# Define the image transformation pipeline
49
transform = transforms.Compose([
50
    transforms.Resize((224, 224)),  # Resize image to match EfficientNet input size
51
    transforms.ToTensor(),  # Convert image to tensor
52
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
53
])
54
55
@app.route('/predict', methods=['POST'])
56
def predict():
57
    if 'file' not in request.files:
58
        return jsonify({'error': 'No file part'}), 400
59
    
60
    file = request.files['file']
61
    
62
    # Read and process the image
63
    try:
64
        img_bytes = file.read()
65
        img = Image.open(io.BytesIO(img_bytes))
66
        img = transform(img).unsqueeze(0)  # Add batch dimension
67
    except Exception as e:
68
        return jsonify({'error': str(e)}), 400
69
    
70
    # Inference
71
    with torch.no_grad():
72
        outputs = model(img)
73
        _, predicted = torch.max(outputs, 1)  # Get the class with the highest probability
74
    
75
    return jsonify({'prediction': int(predicted.item())})
76
77
if __name__ == '__main__':
78
    app.run(debug=True)