--- a +++ b/Streamlit/app.py @@ -0,0 +1,173 @@ +import streamlit as st +import requests +import pandas as pd + +# Define the FastAPI endpoint +url = "http://20.253.116.228/predict" + +# Style for the result +st.markdown(""" + <style> + .big-font { + font-size:30px !important; + font-weight: bold; + } + </style> + """, unsafe_allow_html=True) + +# Function to translate prediction to human-readable form +def interpret_prediction(prediction): + if prediction == 0: + return "Insuffient Weight" + elif prediction == 1: + return "Healthy Weight" + elif prediction == 2: + return "Overweight Level 1" + elif prediction == 3: + return "Overweight Level 2" + elif prediction == 4: + return "Obesity Level 1" + elif prediction == 5: + return "Obesity Level 2" + elif prediction == 6: + return "Obesity Level 3" + else: + return "Unknown Level" + + +def main(): + st.title("Obesity Level Prediction") + + # User can choose to input data via form or upload CSV + input_method = st.radio("How would you like to input data?", ('Fill Form', 'Upload CSV')) + + if input_method == 'Fill Form': + # Initialize form fields with default values + form_data = { + "id": 0, + "Gender": 'Male', + "Age": 25.0, + "Height": 1.75, + "Weight": 70.0, + "family_history_with_overweight": 'yes', + "FAVC": 'yes', + "FCVC": 2.0, + "NCP": 3.0, + "CAEC": 'Sometimes', + "SMOKE": 'no', + "CH2O": 2.0, + "SCC": 'no', + "FAF": 1.0, + "TUE": 2.0, + "CALC": 'Sometimes', + "MTRANS": 'Public_Transportation' + } + + # Create form to input data or display data from CSV + with st.form(key='my_form'): + id_field = st.number_input('ID', value=form_data["id"]) + gender = st.selectbox('Gender', ['Male', 'Female'], index=['Male', 'Female'].index(form_data["Gender"])) + age = st.number_input('Age', min_value=18.00, max_value=100.00, value=form_data["Age"]) + height = st.number_input('Height (meters)', value=form_data["Height"]) + weight = st.number_input('Weight (kg)', value=form_data["Weight"]) + family_history_with_overweight = st.selectbox('Family history of overweight', ['yes', 'no'], index=['yes', 'no'].index(form_data["family_history_with_overweight"])) + favc = st.selectbox('Frequent consumption of high caloric food', ['yes', 'no'], index=['yes', 'no'].index(form_data["FAVC"])) + fcvc = st.number_input('Frequency of vegetables consumption', min_value=1.00, max_value=3.00, value=form_data["FCVC"]) + ncp = st.number_input('Number of main meals', value=form_data["NCP"]) + caec = st.selectbox('Consumption of food between meals', ['no', 'Sometimes', 'Frequently', 'Always'], index=['no', 'Sometimes', 'Frequently', 'Always'].index(form_data["CAEC"])) + smoke = st.selectbox('Smoking', ['yes', 'no'], index=['yes', 'no'].index(form_data["SMOKE"])) + ch2o = st.number_input('Consumption of water daily (liters)', value=form_data["CH2O"]) + scc = st.selectbox('Calories consumption monitoring', ['yes', 'no'], index=['yes', 'no'].index(form_data["SCC"])) + faf = st.number_input('Physical activity frequency (times per week)', value=form_data["FAF"]) + tue = st.number_input('Time using electronic devices (hours per day)', value=form_data["TUE"]) + calc = st.selectbox('Consumption of alcohol', ['no', 'Sometimes', 'Frequently'], index=['no', 'Sometimes', 'Frequently'].index(form_data["CALC"])) + mtrans = st.selectbox('Transportation used', ['Automobile', 'Motorbike', 'Bike', 'Public_Transportation', 'Walking'], index=['Automobile', 'Motorbike', 'Bike', 'Public_Transportation', 'Walking'].index(form_data["MTRANS"])) + submit_button = st.form_submit_button(label='Predict Obesity Level') + + # Handle form submission + if submit_button: + # Construct the request payload + data = { + "id": id_field, + "Gender": gender, + "Age": age, + "Height": height, + "Weight": weight, + "family_history_with_overweight": family_history_with_overweight, + "FAVC": favc, + "FCVC": fcvc, + "NCP": ncp, + "CAEC": caec, + "SMOKE": smoke, + "CH2O": ch2o, + "SCC": scc, + "FAF": faf, + "TUE": tue, + "CALC": calc, + "MTRANS": mtrans + } + + # Send a post request to the server + response = requests.post(url, json=data) + if response.status_code == 200: + result = response.json() + # Interpret the prediction for the user + prediction = interpret_prediction(result.get('prediction', -1)) + # Display the prediction result + st.markdown(f'<p class="big-font">Obesity Level: {prediction}</p>', unsafe_allow_html=True) + else: + st.error("Failed to get a valid response from the model.") + + elif input_method == 'Upload CSV': + uploaded_file = st.file_uploader("Choose a CSV file", type="csv") + if uploaded_file is not None: + dataframe = pd.read_csv(uploaded_file) + # Create a new column for predictions in the dataframe + dataframe['Obesity Level'] = None + + # Iterate over the rows of the dataframe and make predictions + with st.spinner('Making predictions...'): + for index, row in dataframe.iterrows(): + # Convert the row to dictionary + data = row.to_dict() + print(data) + # Remove the 'Obesity Level' key if present + data.pop('Obesity Level', None) + # Send a post request to the server + response = requests.post(url, json=data) + if response.status_code == 200: + result = response.json() + # Interpret the prediction for the user + prediction = interpret_prediction(result.get('prediction', -1)) + # Update the dataframe with predictions + dataframe.at[index, 'Obesity Level'] = prediction + else: + st.error(f"Failed to get a valid response from the model for row {index+1}: {response.text}") + break # Stop the loop if there is an error + + # Only proceed if all predictions were successful + if not dataframe['Obesity Level'].isnull().any(): + st.success('All predictions made successfully!') + # Display the dataframe with predictions + st.dataframe(dataframe) + # Allow the user to download the augmented CSV + st.download_button( + label="Download CSV with predictions", + data=dataframe.to_csv(index=False), + file_name='predictions.csv', + mime='text/csv' + ) + else: + st.error(f"Failed to get a valid response from the model: {response.text}") + +if __name__ == "__main__": + main() + + + + + + + + +