--- a +++ b/api/main.py @@ -0,0 +1,56 @@ +from fastapi import FastAPI, File, UploadFile +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +import numpy as np +from io import BytesIO +from PIL import Image +import tensorflow as tf + +app = FastAPI() +# Define a list of allowed origins (modify as needed) +origins = [ + "http://localhost", + "http://localhost:3000", +] + +# Add CORS middleware to allow cross-origin requests +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +MODEL = tf.keras.models.load_model("../models/1") +CLASS_NAMES = ['Benign', '[Malignant] Pre-B', + '[Malignant] Pro-B', '[Malignant] early Pre-B'] + + +@app.get("/ping") +async def ping(): + return "Hello, I am alive" + + +def read_file_as_image(data) -> np.ndarray: + image = np.array(Image.open(BytesIO(data))) + return image + + +@app.post("/predict") +async def predict( + file: UploadFile = File(...) +): + image = read_file_as_image(await file.read()) + image = tf.image.resize(image, (264, 264)) + img_batch = np.expand_dims(image, 0) + + predictions = MODEL.predict(img_batch) + predicted_class = CLASS_NAMES[np.argmax(predictions[0])] + confidence = np.max(predictions[0]) + return { + 'class': predicted_class, + 'confidence': float(confidence) + } +if __name__ == "__main__": + uvicorn.run(app, host='localhost', port=8000)