Diff of /api/main.py [000000] .. [70ba26]

Switch to unified view

a b/api/main.py
1
from fastapi import FastAPI, File, UploadFile
2
from fastapi.middleware.cors import CORSMiddleware
3
import uvicorn
4
import numpy as np
5
from io import BytesIO
6
from PIL import Image
7
import tensorflow as tf
8
9
app = FastAPI()
10
# Define a list of allowed origins (modify as needed)
11
origins = [
12
    "http://localhost",
13
    "http://localhost:3000",
14
]
15
16
# Add CORS middleware to allow cross-origin requests
17
app.add_middleware(
18
    CORSMiddleware,
19
    allow_origins=origins,
20
    allow_credentials=True,
21
    allow_methods=["*"],
22
    allow_headers=["*"],
23
)
24
25
MODEL = tf.keras.models.load_model("../models/1")
26
CLASS_NAMES = ['Benign', '[Malignant] Pre-B',
27
               '[Malignant] Pro-B', '[Malignant] early Pre-B']
28
29
30
@app.get("/ping")
31
async def ping():
32
    return "Hello, I am alive"
33
34
35
def read_file_as_image(data) -> np.ndarray:
36
    image = np.array(Image.open(BytesIO(data)))
37
    return image
38
39
40
@app.post("/predict")
41
async def predict(
42
    file: UploadFile = File(...)
43
):
44
    image = read_file_as_image(await file.read())
45
    image = tf.image.resize(image, (264, 264))
46
    img_batch = np.expand_dims(image, 0)
47
48
    predictions = MODEL.predict(img_batch)
49
    predicted_class = CLASS_NAMES[np.argmax(predictions[0])]
50
    confidence = np.max(predictions[0])
51
    return {
52
        'class': predicted_class,
53
        'confidence': float(confidence)
54
    }
55
if __name__ == "__main__":
56
    uvicorn.run(app, host='localhost', port=8000)